diff --git a/Project.toml b/Project.toml index 16e8e213..699dd109 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,14 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "0.12.1" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +ChainRulesCore = "1" julia = "1" [extras] diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 9e6b0a7d..7b33922d 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -626,6 +626,7 @@ end include("fillalgebra.jl") include("fillbroadcast.jl") include("trues.jl") +include("chainrules.jl") ## # print diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 00000000..01f35580 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,38 @@ +import ChainRulesCore: ProjectTo, NoTangent, Tangent + +""" + ProjectTo(::Fill) -> ProjectTo{Fill} + ProjectTo(::Ones) -> ProjectTo{NoTangent} + +Most FillArrays arrays store one number, and so their gradients under automatic +differentiation represent the variation of this one number. + +The exception is those like `Ones` and `Zeros` whose type fixes their value, +which have no graidient. +""" +ProjectTo(x::Fill{<:Number}) = ProjectTo{Fill}(; element = ProjectTo(getindex_value(x)), axes = axes(x)) + +ProjectTo(x::AbstractFill{Bool}) = ProjectTo{NoTangent}() # Bool is always regarded as categorical + +ProjectTo(x::Zeros) = ProjectTo{NoTangent}() +ProjectTo(x::Ones) = ProjectTo{NoTangent}() + +function (project::ProjectTo{Fill})(dx::AbstractArray) + for d in 1:max(ndims(dx), length(project.axes)) + size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(axes_x, size(dx))) + end + Fill(mean(dx), project.axes) # Note that mean(dx::Fill) is optimised +end + +function (project::ProjectTo{Fill})(dx::Tangent{<:Fill}) + # This would need a definition for length(::NoTangent) to be safe: + # for d in 1:max(length(dx.axes), length(project.axes)) + # length(get(dx.axes, d, 1)) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(dx.axes, size(dx))) + # end + Fill(dx.value / prod(length, project.axes), project.axes) +end + +function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) + size_x = map(length, axes_x) + DimensionMismatch("variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx") +end diff --git a/test/runtests.jl b/test/runtests.jl index a688ef54..0255bdf2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,7 @@ -using FillArrays, LinearAlgebra, SparseArrays, StaticArrays, Random, Base64, Test, Statistics + +using FillArrays, StaticArrays, ChainRulesCore, Base64 +using LinearAlgebra, SparseArrays, Random, Statistics, Test # standard libraries + import FillArrays: AbstractFill, RectDiagonal, SquareEye @testset "fill array constructors and convert" begin @@ -1323,3 +1326,14 @@ end @test cor(Fill(3, 4, 5)) ≈ cor(fill(3, 4, 5)) nans=true @test cor(Fill(3, 4, 5), dims=2) ≈ cor(fill(3, 4, 5), dims=2) nans=true end + +@testset "ChainRules integration" begin + x = Fill(1,2,3) + @test ProjectTo(x)(ones(2,3)) === Fill(1.0, 2, 3) + @test ProjectTo(x)(ones(2,3,1) .+ im) === Fill(1.0, 2, 3) + @test ProjectTo(x)(Fill(1+im, 2,3)) === Fill(1.0, 2, 3) + @test ProjectTo(x)(Tangent{typeof(x)}(; value=6)) === Fill(1.0, 2, 3) + + @test ProjectTo(Eye(3))(rand(3,3)) === NoTangent() + @test ProjectTo(Zeros(3))(rand(3)) === NoTangent() +end