|
| 1 | +import argparse |
| 2 | +import os.path as osp |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn.functional as F |
| 6 | + |
| 7 | +import torch_geometric.transforms as T |
| 8 | +from torch_geometric.datasets import Planetoid |
| 9 | +from torch_geometric.nn import GCNConv |
| 10 | +from torch_geometric.utils import negative_sampling, train_test_split_edges |
| 11 | + |
| 12 | + |
| 13 | +class GCNEncoder(torch.nn.Module): |
| 14 | + def __init__(self, in_channels, hidden_channels, out_channels): |
| 15 | + super().__init__() |
| 16 | + self.conv1 = GCNConv(in_channels, hidden_channels) |
| 17 | + self.conv2 = GCNConv(hidden_channels, out_channels) |
| 18 | + |
| 19 | + def forward(self, x, edge_index): |
| 20 | + x = self.conv1(x, edge_index).relu() |
| 21 | + return self.conv2(x, edge_index) |
| 22 | + |
| 23 | + |
| 24 | +class LinkPredictor(torch.nn.Module): |
| 25 | + def __init__(self, in_channels, hidden_channels): |
| 26 | + super().__init__() |
| 27 | + self.lin1 = torch.nn.Linear(in_channels * 2, hidden_channels) |
| 28 | + self.lin2 = torch.nn.Linear(hidden_channels, 1) |
| 29 | + |
| 30 | + def forward(self, z_i, z_j): |
| 31 | + x = torch.cat([z_i, z_j], dim=1) |
| 32 | + x = self.lin1(x).relu() |
| 33 | + x = self.lin2(x) |
| 34 | + return x.view(-1) |
| 35 | + |
| 36 | + |
| 37 | +class ARLinkPredictor(torch.nn.Module): |
| 38 | + def __init__(self, in_channels): |
| 39 | + super().__init__() |
| 40 | + # Split dimensions between attract and repel |
| 41 | + self.attract_dim = in_channels // 2 |
| 42 | + self.repel_dim = in_channels - self.attract_dim |
| 43 | + |
| 44 | + def forward(self, z_i, z_j): |
| 45 | + # Split into attract and repel parts |
| 46 | + z_i_attr = z_i[:, :self.attract_dim] |
| 47 | + z_i_repel = z_i[:, self.attract_dim:] |
| 48 | + |
| 49 | + z_j_attr = z_j[:, :self.attract_dim] |
| 50 | + z_j_repel = z_j[:, self.attract_dim:] |
| 51 | + |
| 52 | + # Calculate AR score |
| 53 | + attract_score = (z_i_attr * z_j_attr).sum(dim=1) |
| 54 | + repel_score = (z_i_repel * z_j_repel).sum(dim=1) |
| 55 | + |
| 56 | + return attract_score - repel_score |
| 57 | + |
| 58 | + |
| 59 | +def train(encoder, predictor, data, optimizer): |
| 60 | + encoder.train() |
| 61 | + predictor.train() |
| 62 | + |
| 63 | + # Forward pass and calculate loss |
| 64 | + optimizer.zero_grad() |
| 65 | + z = encoder(data.x, data.train_pos_edge_index) |
| 66 | + |
| 67 | + # Positive edges |
| 68 | + pos_out = predictor(z[data.train_pos_edge_index[0]], |
| 69 | + z[data.train_pos_edge_index[1]]) |
| 70 | + |
| 71 | + # Sample and predict on negative edges |
| 72 | + neg_edge_index = negative_sampling( |
| 73 | + edge_index=data.train_pos_edge_index, |
| 74 | + num_nodes=data.num_nodes, |
| 75 | + num_neg_samples=data.train_pos_edge_index.size(1), |
| 76 | + ) |
| 77 | + neg_out = predictor(z[neg_edge_index[0]], z[neg_edge_index[1]]) |
| 78 | + |
| 79 | + # Calculate loss |
| 80 | + pos_loss = F.binary_cross_entropy_with_logits(pos_out, |
| 81 | + torch.ones_like(pos_out)) |
| 82 | + neg_loss = F.binary_cross_entropy_with_logits(neg_out, |
| 83 | + torch.zeros_like(neg_out)) |
| 84 | + loss = pos_loss + neg_loss |
| 85 | + |
| 86 | + loss.backward() |
| 87 | + optimizer.step() |
| 88 | + |
| 89 | + return loss.item() |
| 90 | + |
| 91 | + |
| 92 | +@torch.no_grad() |
| 93 | +def test(encoder, predictor, data): |
| 94 | + encoder.eval() |
| 95 | + predictor.eval() |
| 96 | + |
| 97 | + z = encoder(data.x, data.train_pos_edge_index) |
| 98 | + |
| 99 | + pos_val_out = predictor(z[data.val_pos_edge_index[0]], |
| 100 | + z[data.val_pos_edge_index[1]]) |
| 101 | + neg_val_out = predictor(z[data.val_neg_edge_index[0]], |
| 102 | + z[data.val_neg_edge_index[1]]) |
| 103 | + |
| 104 | + pos_test_out = predictor(z[data.test_pos_edge_index[0]], |
| 105 | + z[data.test_pos_edge_index[1]]) |
| 106 | + neg_test_out = predictor(z[data.test_neg_edge_index[0]], |
| 107 | + z[data.test_neg_edge_index[1]]) |
| 108 | + |
| 109 | + val_auc = compute_auc(pos_val_out, neg_val_out) |
| 110 | + test_auc = compute_auc(pos_test_out, neg_test_out) |
| 111 | + |
| 112 | + return val_auc, test_auc |
| 113 | + |
| 114 | + |
| 115 | +def compute_auc(pos_out, neg_out): |
| 116 | + pos_out = torch.sigmoid(pos_out).cpu().numpy() |
| 117 | + neg_out = torch.sigmoid(neg_out).cpu().numpy() |
| 118 | + |
| 119 | + # Simple AUC calculation |
| 120 | + from sklearn.metrics import roc_auc_score |
| 121 | + y_true = torch.cat( |
| 122 | + [torch.ones(pos_out.shape[0]), |
| 123 | + torch.zeros(neg_out.shape[0])]) |
| 124 | + y_score = torch.cat([torch.tensor(pos_out), torch.tensor(neg_out)]) |
| 125 | + |
| 126 | + return roc_auc_score(y_true, y_score) |
| 127 | + |
| 128 | + |
| 129 | +def main(): |
| 130 | + parser = argparse.ArgumentParser() |
| 131 | + parser.add_argument('--dataset', type=str, default='Cora', |
| 132 | + choices=['Cora', 'CiteSeer', 'PubMed']) |
| 133 | + parser.add_argument('--hidden_channels', type=int, default=128) |
| 134 | + parser.add_argument('--out_channels', type=int, default=64) |
| 135 | + parser.add_argument('--epochs', type=int, default=200) |
| 136 | + parser.add_argument('--use_ar', action='store_true', |
| 137 | + help='Use Attract-Repel embeddings') |
| 138 | + parser.add_argument('--lr', type=float, default=0.01) |
| 139 | + args = parser.parse_args() |
| 140 | + |
| 141 | + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 142 | + |
| 143 | + # Load dataset |
| 144 | + transform = T.Compose([ |
| 145 | + T.NormalizeFeatures(), |
| 146 | + T.ToDevice(device), |
| 147 | + ]) |
| 148 | + |
| 149 | + path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', |
| 150 | + args.dataset) |
| 151 | + dataset = Planetoid(path, args.dataset, transform=transform) |
| 152 | + data = dataset[0] |
| 153 | + |
| 154 | + # Process data for link prediction |
| 155 | + data = train_test_split_edges(data) |
| 156 | + |
| 157 | + # Initialize encoder |
| 158 | + encoder = GCNEncoder( |
| 159 | + in_channels=dataset.num_features, |
| 160 | + hidden_channels=args.hidden_channels, |
| 161 | + out_channels=args.out_channels, |
| 162 | + ).to(device) |
| 163 | + |
| 164 | + # Choose predictor based on args |
| 165 | + if args.use_ar: |
| 166 | + predictor = ARLinkPredictor(in_channels=args.out_channels).to(device) |
| 167 | + print(f"Running link prediction on {args.dataset}" |
| 168 | + f"with Attract-Repel embeddings") |
| 169 | + else: |
| 170 | + predictor = LinkPredictor( |
| 171 | + in_channels=args.out_channels, |
| 172 | + hidden_channels=args.hidden_channels).to(device) |
| 173 | + print(f"Running link prediction on {args.dataset}" |
| 174 | + f"with Traditional embeddings") |
| 175 | + |
| 176 | + optimizer = torch.optim.Adam( |
| 177 | + list(encoder.parameters()) + list(predictor.parameters()), lr=args.lr) |
| 178 | + |
| 179 | + best_val_auc = 0 |
| 180 | + final_test_auc = 0 |
| 181 | + |
| 182 | + for epoch in range(1, args.epochs + 1): |
| 183 | + loss = train(encoder, predictor, data, optimizer) |
| 184 | + val_auc, test_auc = test(encoder, predictor, data) |
| 185 | + |
| 186 | + if val_auc > best_val_auc: |
| 187 | + best_val_auc = val_auc |
| 188 | + final_test_auc = test_auc |
| 189 | + |
| 190 | + if epoch % 10 == 0: |
| 191 | + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' |
| 192 | + f'Val AUC: {val_auc:.4f}, ' |
| 193 | + f'Test AUC: {test_auc:.4f}') |
| 194 | + |
| 195 | + print(f'Final results - Val AUC: {best_val_auc:.4f}, ' |
| 196 | + f'Test AUC: {final_test_auc:.4f}') |
| 197 | + |
| 198 | + # Calculate R-fraction if using AR |
| 199 | + if args.use_ar: |
| 200 | + with torch.no_grad(): |
| 201 | + z = encoder(data.x, data.train_pos_edge_index) |
| 202 | + attr_dim = args.out_channels // 2 |
| 203 | + |
| 204 | + z_attr = z[:, :attr_dim] |
| 205 | + z_repel = z[:, attr_dim:] |
| 206 | + |
| 207 | + attract_norm_squared = torch.sum(z_attr**2) |
| 208 | + repel_norm_squared = torch.sum(z_repel**2) |
| 209 | + |
| 210 | + r_fraction = repel_norm_squared / (attract_norm_squared + |
| 211 | + repel_norm_squared) |
| 212 | + print(f"R-fraction: {r_fraction.item():.4f}") |
| 213 | + |
| 214 | + |
| 215 | +if __name__ == '__main__': |
| 216 | + main() |
0 commit comments