Skip to content

Commit 3c80741

Browse files
authored
Merge pull request #79 from yuehhua/fg
Remove Reference from FeaturedGraph field
2 parents 0f2ac81 + 540288e commit 3c80741

File tree

2 files changed

+33
-37
lines changed

2 files changed

+33
-37
lines changed

src/graph/featuredgraphs.jl

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,11 @@ References to graph or features are hold in this type.
1919
- `edge_feature`: edge features attached to graph.
2020
- `gloabl_feature`: gloabl graph features attached to graph.
2121
"""
22-
struct FeaturedGraph{T,S,R,Q} <: AbstractFeaturedGraph
23-
graph::Ref{T}
24-
nf::Ref{S}
25-
ef::Ref{R}
26-
gf::Ref{Q}
27-
28-
function FeaturedGraph(graph::T, nf::S, ef::R, gf::Q) where {T,S<:AbstractMatrix,R<:AbstractMatrix,Q<:AbstractVector}
29-
new{T,S,R,Q}(Ref(graph), Ref(nf), Ref(ef), Ref(gf))
30-
end
22+
mutable struct FeaturedGraph{T,S,R,Q} <: AbstractFeaturedGraph
23+
graph::T
24+
nf::S
25+
ef::R
26+
gf::Q
3127
end
3228

3329
FeaturedGraph() = FeaturedGraph(zeros(0,0), zeros(0,0), zeros(0,0), zeros(0))
@@ -42,97 +38,97 @@ FeaturedGraph(graph::T, nf::AbstractMatrix) where {T} = FeaturedGraph(graph, nf,
4238
Get referenced graph.
4339
"""
4440
graph(::NullGraph) = nothing
45-
graph(fg::FeaturedGraph) = fg.graph[]
41+
graph(fg::FeaturedGraph) = fg.graph
4642

4743
"""
4844
node_feature(::AbstractFeaturedGraph)
4945
5046
Get node feature attached to graph.
5147
"""
5248
node_feature(::NullGraph) = nothing
53-
node_feature(fg::FeaturedGraph) = fg.nf[]
49+
node_feature(fg::FeaturedGraph) = fg.nf
5450

5551
"""
5652
edge_feature(::AbstractFeaturedGraph)
5753
5854
Get edge feature attached to graph.
5955
"""
6056
edge_feature(::NullGraph) = nothing
61-
edge_feature(fg::FeaturedGraph) = fg.ef[]
57+
edge_feature(fg::FeaturedGraph) = fg.ef
6258

6359
"""
6460
global_feature(::AbstractFeaturedGraph)
6561
6662
Get global feature attached to graph.
6763
"""
6864
global_feature(::NullGraph) = nothing
69-
global_feature(fg::FeaturedGraph) = fg.gf[]
65+
global_feature(fg::FeaturedGraph) = fg.gf
7066

7167
has_graph(::NullGraph) = false
72-
has_graph(fg::FeaturedGraph) = fg.graph[] != zeros(0,0)
68+
has_graph(fg::FeaturedGraph) = fg.graph != zeros(0,0)
7369

7470
has_node_feature(::NullGraph) = false
75-
has_node_feature(fg::FeaturedGraph) = fg.nf[] != zeros(0,0)
71+
has_node_feature(fg::FeaturedGraph) = fg.nf != zeros(0,0)
7672

7773
has_edge_feature(::NullGraph) = false
78-
has_edge_feature(fg::FeaturedGraph) = fg.ef[] != zeros(0,0)
74+
has_edge_feature(fg::FeaturedGraph) = fg.ef != zeros(0,0)
7975

8076
has_global_feature(::NullGraph) = false
81-
has_global_feature(fg::FeaturedGraph) = fg.gf[] != zeros(0)
77+
has_global_feature(fg::FeaturedGraph) = fg.gf != zeros(0)
8278

8379
"""
8480
adjacency_list(::AbstractFeaturedGraph)
8581
8682
Get adjacency list of graph.
8783
"""
8884
adjacency_list(::NullGraph) = [zeros(0)]
89-
adjacency_list(fg::FeaturedGraph) = adjacency_list(fg.graph[])
85+
adjacency_list(fg::FeaturedGraph) = adjacency_list(fg.graph)
9086

9187
"""
9288
nv(::AbstractFeaturedGraph)
9389
9490
Get node number of graph.
9591
"""
9692
nv(::NullGraph) = 0
97-
nv(fg::FeaturedGraph) = nv(fg.graph[])
98-
nv(fg::FeaturedGraph{T}) where {T<:AbstractMatrix} = size(fg.graph[], 1)
99-
nv(fg::FeaturedGraph{T}) where {T<:AbstractVector} = length(fg.graph[])
93+
nv(fg::FeaturedGraph) = nv(fg.graph)
94+
nv(fg::FeaturedGraph{T}) where {T<:AbstractMatrix} = size(fg.graph, 1)
95+
nv(fg::FeaturedGraph{T}) where {T<:AbstractVector} = length(fg.graph)
10096

10197
"""
10298
ne(::AbstractFeaturedGraph)
10399
104100
Get edge number of graph.
105101
"""
106102
ne(::NullGraph) = 0
107-
ne(fg::FeaturedGraph) = ne(fg.graph[])
108-
ne(fg::FeaturedGraph{T}) where {T<:AbstractVector} = sum(map(length, fg.graph[]))÷2
103+
ne(fg::FeaturedGraph) = ne(fg.graph)
104+
ne(fg::FeaturedGraph{T}) where {T<:AbstractVector} = sum(map(length, fg.graph))÷2
109105

110106

111107

112108
## Linear algebra API for AbstractFeaturedGraph
113109

114-
adjacency_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[])) = adjacency_matrix(fg.graph[], T)
110+
adjacency_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph)) = adjacency_matrix(fg.graph, T)
115111

116-
function degrees(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
117-
degrees(fg.graph[], T; dir=dir)
112+
function degrees(fg::FeaturedGraph, T::DataType=eltype(fg.graph); dir::Symbol=:out)
113+
degrees(fg.graph, T; dir=dir)
118114
end
119115

120-
function degree_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
121-
degree_matrix(fg.graph[], T; dir=dir)
116+
function degree_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph); dir::Symbol=:out)
117+
degree_matrix(fg.graph, T; dir=dir)
122118
end
123119

124-
function inv_sqrt_degree_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
125-
inv_sqrt_degree_matrix(fg.graph[], T; dir=dir)
120+
function inv_sqrt_degree_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph); dir::Symbol=:out)
121+
inv_sqrt_degree_matrix(fg.graph, T; dir=dir)
126122
end
127123

128-
function laplacian_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); dir::Symbol=:out)
129-
laplacian_matrix(fg.graph[], T; dir=dir)
124+
function laplacian_matrix(fg::FeaturedGraph, T::DataType=eltype(fg.graph); dir::Symbol=:out)
125+
laplacian_matrix(fg.graph, T; dir=dir)
130126
end
131127

132-
function normalized_laplacian(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]); selfloop::Bool=false)
133-
normalized_laplacian(fg.graph[], T; selfloop=selfloop)
128+
function normalized_laplacian(fg::FeaturedGraph, T::DataType=eltype(fg.graph); selfloop::Bool=false)
129+
normalized_laplacian(fg.graph, T; selfloop=selfloop)
134130
end
135131

136132
function scaled_laplacian(fg::FeaturedGraph, T::DataType=eltype(fg.graph[]))
137-
scaled_laplacian(fg.graph[], T)
133+
scaled_laplacian(fg.graph, T)
138134
end

src/layers/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ end
5353
function (g::GCNConv)(fg::FeaturedGraph)
5454
X = node_feature(fg)
5555
A = adjacency_matrix(fg)
56-
g.fg isa NullGraph || (g.fg.graph[] = A)
56+
g.fg isa NullGraph || (g.fg.graph = A)
5757
L = normalized_laplacian(A, eltype(X); selfloop=true)
5858
X_ = g.σ.(g.weight * X * L .+ g.bias)
5959
FeaturedGraph(A, X_)
@@ -136,7 +136,7 @@ end
136136
function (c::ChebConv)(fg::FeaturedGraph)
137137
@assert has_graph(fg) "A given FeaturedGraph must contain a graph."
138138
g = graph(fg)
139-
c.fg isa NullGraph || (c.fg.graph[] = g)
139+
c.fg isa NullGraph || (c.fg.graph = g)
140140
X = node_feature(fg)
141141
= scaled_laplacian(adjacency_matrix(fg))
142142
= convert(typeof(X), L̃)

0 commit comments

Comments
 (0)