Skip to content

Commit d8b3437

Browse files
authored
Merge pull request #57 from yuehhua/fix
Ignore gradient of generate_cluster
2 parents 0916223 + accce31 commit d8b3437

File tree

4 files changed

+11
-7
lines changed

4 files changed

+11
-7
lines changed

src/GeometricFlux.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ module GeometricFlux
33
using Statistics: mean
44
using StaticArrays: StaticArray
55
using SparseArrays: SparseMatrixCSC
6-
using LinearAlgebra: I, issymmetric, diagm, eigmax, norm, Adjoint
6+
using LinearAlgebra: I, issymmetric, diagm, eigmax, norm, Adjoint, Diagonal
77

88
using Requires
99
using DataStructures: DefaultDict
1010
using Flux
1111
using Flux: glorot_uniform, leakyrelu, GRUCell
1212
using Flux: @functor
1313
using LightGraphs
14+
using Zygote
1415
using ZygoteRules
1516
using FillArrays: Fill
1617

src/layers/meta.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ function propagate(meta::T; adjl=adjlist(meta), kwargs...) where {T<:Meta}
3434
newE = update_edge(meta; gi=gi, kwargs...)
3535

3636
if haskey(kwargs, :naggr)
37-
= aggregate_neighbors(meta, kwargs[:naggr]; E=newE, cluster=generate_cluster(newE, gi))
37+
cluster = generate_cluster(newE, gi)
38+
= aggregate_neighbors(meta, kwargs[:naggr]; E=newE, cluster=cluster)
3839
kwargs = (kwargs..., Ē=Ē)
3940
end
4041

@@ -54,7 +55,7 @@ function propagate(meta::T; adjl=adjlist(meta), kwargs...) where {T<:Meta}
5455
(newE, newV, new_u)
5556
end
5657

57-
function generate_cluster(M::AbstractArray{T,N}, gi::GraphInfo) where {T,N}
58+
Zygote.@nograd function generate_cluster(M::AbstractArray{T,N}, gi::GraphInfo) where {T,N}
5859
cluster = similar(M, Int, gi.E)
5960
@inbounds for i = 1:gi.V
6061
j = gi.edge_idx[i]

src/layers/msgpass.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ function propagate(mp::T; aggr::Symbol=:add, adjl=adjlist(mp), kwargs...) where
4141
M = update_edge(mp; gi=gi, kwargs...)
4242

4343
# aggregate function
44-
M = aggregate_neighbors(mp, aggr; M=M, cluster=generate_cluster(M, gi))
44+
cluster = generate_cluster(M, gi)
45+
M = aggregate_neighbors(mp, aggr; M=M, cluster=cluster)
4546

4647
# update function
4748
Y = update_vertex(mp; M=M, kwargs...)

src/operations/linalg.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## Linear algebra API for adjacency matrix
2-
using LinearAlgebra
2+
3+
Zygote.@nograd issymmetric
34

45
function adjacency_matrix(adj::AbstractMatrix, T::DataType=eltype(adj))
56
m, n = size(adj)
@@ -124,9 +125,9 @@ function normalized_laplacian(adj::AbstractMatrix, T::DataType=eltype(adj); self
124125
end
125126

126127
@doc raw"""
127-
scaled_laplacian(adj::AbstractMatrix[, T::DataType])
128+
scaled_laplacian(adj::AbstractMatrix[, T::DataType])
128129
129-
Scaled Laplacien matrix of graph `g`,
130+
Scaled Laplacien matrix of graph `g`,
130131
defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix.
131132
132133
# Arguments

0 commit comments

Comments
 (0)