AD in interpreter produces NaN values where it should not #2216

@FluxusMagna

This code produces incorrect output in the interpreter:

```futhark
def identity_mat n = tabulate_2d n n (\i j -> f32.bool (i==j))

def Jacobi [n] (f: [n]f32 -> [n]f32) (x:[n]f32) : [n][n]f32
    = map (\i -> jvp f x i) (identity_mat n) |> transpose

def Hessian [n] (f: [n]f32 -> f32) (x:[n]f32) : [n][n]f32
    = Jacobi (\x -> vjp f x 1) x

def Hessian_test (x:[3]f32) = Hessian (\x -> x[1]**2+x[2]**2) x
```

To test this code, we can run

```
> Hessian_test [0,0,0]
```

which should produce the following (and does, in compiled code):

```
[[0.0f32, 0.0f32, 0.0f32],
 [0.0f32, 2.0f32, 0.0f32],
 [0.0f32, 0.0f32, 2.0f32]]
```
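For reference, the expected values can be checked by hand. `Jacobi` computes `jvp f x e_i` for each basis vector $e_i$, which is column $i$ of the Jacobian $J$; the mapped rows therefore form $J^\top$, hence the `transpose`. `Hessian` applies this to the gradient returned by `vjp f x 1`, so for the test function (0-based indices, as in the code):

$$
f(x) = x_1^2 + x_2^2, \qquad
\nabla f(x) = \begin{pmatrix} 0 \\ 2x_1 \\ 2x_2 \end{pmatrix}, \qquad
H_f(x) = \frac{\partial\, \nabla f}{\partial x} = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{pmatrix}.
$$

The Hessian is constant, so `Hessian_test` should return this matrix for any input, including `[0,0,0]`. It is also symmetric, so this particular test does not depend on the `transpose`.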

But in the interpreter, it produces:

```
[[0.0, f32.nan, f32.nan],
 [0.0, f32.nan, f32.nan],
 [0.0, f32.nan, f32.nan]]
```

I'm not sure why it fails like this, but I have observed similar issues with other nested uses of AD, which eventually led me to this test case.
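To guard against regressions, the reproducer could be turned into a `futhark test` program along these lines. This is only a sketch: the test stanza, the `entry` declaration and the explicit return-type annotation on `identity_mat` are my additions, not part of the original report. The stanza encodes the expected Hessian diag(0, 2, 2) at the failing input.

```futhark
-- Regression test sketch for NaNs from nested AD in the interpreter.
-- The Hessian of x[1]**2 + x[2]**2 is diag(0, 2, 2) at every point.
-- ==
-- entry: Hessian_test
-- input { [0f32, 0f32, 0f32] }
-- output { [[0f32, 0f32, 0f32], [0f32, 2f32, 0f32], [0f32, 0f32, 2f32]] }

-- n-by-n identity matrix; row i is the basis vector e_i.
def identity_mat (n: i64) : [n][n]f32 =
  tabulate_2d n n (\i j -> f32.bool (i == j))

-- Forward-mode Jacobian: jvp f x e_i is column i of J, so the
-- mapped rows form J^T and must be transposed.
def Jacobi [n] (f: [n]f32 -> [n]f32) (x: [n]f32) : [n][n]f32 =
  map (\e -> jvp f x e) (identity_mat n) |> transpose

-- Hessian as the Jacobian of the gradient (forward mode over reverse mode).
def Hessian [n] (f: [n]f32 -> f32) (x: [n]f32) : [n][n]f32 =
  Jacobi (\x -> vjp f x 1) x

entry Hessian_test (x: [3]f32) : [3][3]f32 =
  Hessian (\x -> x[1]**2 + x[2]**2) x
```

Running `futhark test` on this file with a compiled backend, and again with the interpreter (I believe `futhark test -i` selects interpretation only), should show the discrepancy while the bug is present.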

Edit: transposed the Jacobi matrix for correctness.
