diff --git a/examples/gat.jl b/examples/gat.jl
index 64f7e3507..9493cae55 100644
--- a/examples/gat.jl
+++ b/examples/gat.jl
@@ -79,6 +79,7 @@ function train(; kws...)
     # build model
     model = Chain(
         WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
+        Dropout(0.6),
         WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
     ) |> device
 
diff --git a/src/operation.jl b/src/operation.jl
index 309c509f1..234701215 100644
--- a/src/operation.jl
+++ b/src/operation.jl
@@ -29,13 +29,12 @@ function incidence_matrix(xs::AbstractVector{T}, N) where {T}
 end
 
 function indexed_softmax(x::AbstractArray, xs, N; dims=1)
-    # memory pre-allocation approach leads to loss fluctuation but not drop anyway
-    # be aware of model loss while optimizing this code snippet
-    as = map(1:N) do i
-        idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(x))
-        NNlib.softmax(x[idx...]; dims)
+    y = copy(x)
+    for i in 1:N
+        idx = ntuple(j -> (j == dims) ? (xs .== i) : Colon(), ndims(y))
+        NNlib.softmax!(view(y, idx...); dims)
     end
-    return cat(as...; dims)
+    return y
 end
 
 function ∇indexed_softmax(dy::AbstractArray{T}, y::AbstractArray{S}, xs, N; dims=1) where {T,S}
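
For reference, a minimal sketch (not part of the patch) of what the updated indexed_softmax computes: the softmax is normalized independently within each index group given by xs, so every group sums to 1 along dims. The input values below are illustrative only, and the snippet assumes NNlib is available; the loop body mirrors the new implementation in src/operation.jl with dims fixed to 1.

    using NNlib

    x  = [1.0, 2.0, 3.0, 4.0]   # scores (illustrative values, not from the PR)
    xs = [1, 1, 2, 2]           # entry i belongs to group xs[i]
    N  = 2                      # number of groups

    # same pattern as the new indexed_softmax body, specialized to dims = 1
    y = copy(x)
    for i in 1:N
        idx = ntuple(j -> (j == 1) ? (xs .== i) : Colon(), ndims(y))
        NNlib.softmax!(view(y, idx...); dims=1)
    end

    # y ≈ [0.269, 0.731, 0.269, 0.731]: each group normalizes to 1 on its own,
    # matching the old map/cat version but writing into a single buffer.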