Problem
Polaris.Updates.apply_updates/3 crashes when called with 2 arguments (relying on the state \\ nil default) under Nx 0.11.
The error:
```
** (Protocol.UndefinedError) protocol Nx.LazyContainer not implemented for Atom.
Got value:
nil
(nx 0.11.0) lib/nx/lazy_container.ex:91: Nx.LazyContainer.Atom.traverse/3
(nx 0.11.0) lib/nx/defn/compiler.ex:852: anonymous fn/2 in Nx.Defn.Compiler.to_lazy_params/2
```
Cause
apply_updates is defined as:
```elixir
defn apply_updates(params, updates, state \\ nil) do
```
The merge_state/2 deftransformp correctly handles nil at the transform level (returns params unchanged). But Nx 0.11's JIT compiler calls Nx.Defn.Compiler.to_lazy_params/2 on all arguments before the function body runs. nil is an Atom, and Atom does not implement the Nx.LazyContainer protocol, so it crashes before merge_state ever sees it.
This worked in earlier Nx versions, which did not pass the default argument through to_lazy_params.
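The failure mode described in the Cause section can be illustrated without Nx at all. The minimal sketch below assumes a simplified model: ParamWalker and FakeJit are hypothetical stand-ins, not Nx or Polaris APIs. ParamWalker plays the role of Nx.LazyContainer (implemented for maps and numbers, not for atoms), and FakeJit.call/2 walks every argument before running the body, roughly as the issue describes to_lazy_params/2 doing. The real compiler does far more; this only shows why a nil argument fails at protocol dispatch while an empty map gets through.

```elixir
# Hypothetical stand-in for Nx.LazyContainer: it has implementations for
# maps and numbers but none for atoms, so dispatching on nil fails.
defprotocol ParamWalker do
  @doc "Walks a parameter container and returns its number of leaves."
  def leaves(data)
end

defimpl ParamWalker, for: Map do
  # Walk each value; an empty map simply has zero leaves.
  def leaves(map) do
    map
    |> Map.values()
    |> Enum.map(&ParamWalker.leaves/1)
    |> Enum.sum()
  end
end

defimpl ParamWalker, for: [Integer, Float] do
  def leaves(_number), do: 1
end

defmodule FakeJit do
  # Hypothetical: mimics a compiler that walks every argument *before*
  # the function body runs.
  def call(body, args) do
    Enum.each(args, &ParamWalker.leaves/1)
    apply(body, args)
  end
end

# The body tolerates any state (as merge_state/2 does for nil),
# but it never gets to run when state is nil.
body = fn params, _updates, _state -> params end
params = %{"w" => 1.0}
updates = %{"w" => 0.1}

# An empty map is a valid container, so the body runs and returns params.
IO.inspect(FakeJit.call(body, [params, updates, %{}]))

# nil is an Atom with no ParamWalker implementation, so this raises
# Protocol.UndefinedError before body is ever invoked.
try do
  FakeJit.call(body, [params, updates, nil])
rescue
  e in Protocol.UndefinedError -> IO.puts("raised: " <> inspect(e.__struct__))
end
```

Replacing nil with %{} in the last call is the same substitution the suggested fix makes to the default argument.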
Reproduction
```elixir
Mix.install([{:polaris, "~> 0.1"}, {:nx, "~> 0.11"}])

{init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: 0.01)
params = %{"w" => Nx.iota({3, 3}, type: :f32)}
state = init_fn.(params)

# Simulate a gradient step
grads = %{"w" => Nx.broadcast(0.1, {3, 3})}
{updates, _new_state} = update_fn.(grads, state, params)

# This crashes:
Polaris.Updates.apply_updates(params, updates)

# This works:
Polaris.Updates.apply_updates(params, updates, %{})
```
Suggested fix
Change the default from nil to %{}:
```elixir
defn apply_updates(params, updates, state \\ %{}) do
```
The merge_state deftransformp already handles an empty map correctly via the merge_inner path, which is a no-op when state has no matching keys.
Versions
- Polaris 0.1.0
- Nx 0.11.0
- Elixir 1.20.0-rc.3
- Erlang/OTP 27.3