diff --git a/Project.toml b/Project.toml index 959533d71..32e8a9b16 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -27,13 +28,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" CUDA = "3" ChainRulesCore = "1.7" DataStructures = "0.18" -FillArrays = "0.12 - 0.13" +FillArrays = "0.13" Flux = "0.12 - 0.13" GraphMLDatasets = "0.1" GraphSignals = "0.4" Graphs = "1" -NNlib = "0.7 - 0.8" -NNlibCUDA = "0.1 - 0.2" +NNlib = "0.8" +NNlibCUDA = "0.2" +Optimisers = "0.2" Reexport = "1.1" Word2Vec = "0.5" Zygote = "0.6" diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index 81ada1e0f..a850515fa 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -16,6 +16,7 @@ using Flux: glorot_uniform, leakyrelu, GRUCell, @functor @reexport using GraphSignals using Graphs using NNlib, NNlibCUDA +using Optimisers using Zygote import Word2Vec: word2vec, wordvectors, get_vector diff --git a/src/layers/graphlayers.jl b/src/layers/graphlayers.jl index 1e33693b8..877ca2197 100644 --- a/src/layers/graphlayers.jl +++ b/src/layers/graphlayers.jl @@ -59,8 +59,8 @@ end Flux.trainable(l::WithGraph) = (l.layer, ) -function Flux.destructure(m::WithGraph) - p, re = Flux.destructure(m.layer) +function Optimisers.destructure(m::WithGraph) + p, re = destructure(m.layer) function re_withgraph(x) WithGraph(re(x), m.fg) end