With the new caching API introduced in https://github.com/LuxDL/Lux.jl/pull/640, this shouldn't be too hard.