Skip to content

Type error after pass 'simplify' / Duplicate definition of function after internalisation #2218

@FluxusMagna

Description

@FluxusMagna

I first got this error:

After internalisation:
Duplicate definition of function entry_QR_decomp

Then I reduced the code a bit and got this:

Type error after pass 'simplify':
In function entry_map_QR_decomp
When checking function body
In expression of statement
  {AG_r_15103 : ({}, [nr_12256][n_12257][d<{(+) m n}>_13209]f32)}
Inside the loop body
In expression of statement
  {AG_15134 : ({}, [nr_12256][n_12257][d<{(+) m n}>_13209]f32)}
in body of case {true}
In expression of statement
  {AG_15135 : ({}, [nr_12256][n_12257][d<{(+) m n}>_13209]f32)}
In expression of statement
  {scatter_res_15227 : ({}, [n_12257][d<{(+) m n}>_13209]f32)}
1 indices given, but type of indexee has 2 dimension(s).

The reduced code:

-- For an array of segment sizes, build two flat arrays of length
-- (sum of szs): the first holds, for every flat position, the index of
-- the segment that position falls in; the second holds the position's
-- offset inside that segment.  Returns two empty arrays when n == 0.
let repl_segm_iota [n] (szs:[n]i64) : ([]i64, []i64) =
    -- Guard: 'last offsets' below is only reached with at least one segment.
    if n == 0 then ([], []) else
    -- Inclusive prefix sums of the sizes: offsets[k] is where segment k ends.
    let offsets = scan (+) 0 szs
    -- Taking each size back out gives where segment k starts.
    let start_idx = map2 (-) offsets szs
    -- Total number of flat elements.
    let sz = last offsets
    -- Add 1 at every segment end position, then prefix-sum, so repl[i]
    -- counts the segment ends at or before i.
    -- NOTE(review): the final offset equals sz, which is past the end of
    -- the length-sz destination; this relies on hist ignoring
    -- out-of-range indices -- confirm.
    let repl = hist (+) 0 sz offsets (replicate n 1)
            |> scan (+) 0
    -- Offset of flat position i relative to the start of its segment.
    let segm = tabulate sz (\i -> i - start_idx[repl[i]])
    in (repl, segm)

-- Flattening 'expand': every element x of 'as' contributes 'size x'
-- results, the k'th of which is 'get x k'.  All results are returned
-- in one flat array.
def expand [n] 'a 'b (size: a -> i64) (get: a -> i64 -> b) (as:[n]a) : []b =
    -- is[p] = which source element flat position p belongs to,
    -- js[p] = the position's index within that element's results.
    let (is, js) = repl_segm_iota (map size as)
    in map2 (\i j -> get as[i] j) is js

-- Compute a triple of rotation coefficients from the pair (a, b).
-- Callers in this file bind the first two components as (c, s) and
-- discard the third.  They combine rows as (c*x - s*y, s*x + c*y).
local def givens (a:f32) (b:f32) =
    -- Like f32.sgn, except that a zero sign is replaced by 1.
    let sign x = let s = f32.sgn x in if s == 0 then 1 else s
    let abs = f32.abs
    -- b is already zero: no mixing of the two inputs.
    in   if b == 0 then
        (sign a, 0, abs a)
    -- a is zero and b is not.
    else if a == 0 then
        (0, -sign b, abs b)
    -- General case: divide the smaller magnitude by the larger, so the
    -- ratio t satisfies |t| <= 1.
    else if abs a > abs b then
        let t = b / a
        let u = sign a * f32.sqrt (1+t*t)
        let ui = f32.recip u
        in (ui, -ui * t, a * u)
    else
        let t = a / b
        let u = sign a * f32.sqrt (1+t*t)
        let ui = f32.recip u
        in (t * ui, -ui, b * u)

-- n-by-n f32 matrix whose entry (i, j) is 'f32.bool (i==j)'.
def identity_mat n = tabulate_2d n n (\i j -> f32.bool (i==j))
-- n-by-n matrix with d[i] at position (i, i) and 0 everywhere else.
def diagonal_mat [n] (d:[n]f32) = tabulate_2d n n (\i j -> if i==j then d[i] else 0)

-- Apply a batch of rotations to A and G at once.  Each element of
-- 'indices' is (col, i, j): the coefficients come from A[i,col] and
-- A[j,col], and rows i and j of both A and G are replaced by the rotated
-- rows.
-- Row indices must be unique: all new rows are written back with a
-- single scatter per matrix.
local def givens_par_rot [n][m][k] (indices:[k](i64,i64,i64))  (A:*[n][m]f32) (G:*[n][n]f32) =
    -- Compute the replacement rows i and j of A and G for one rotation.
    let new_rows (col,i,j) =
        let a = A[i,col]
        let b = A[j,col]
        let (c,s,_) = givens a b
        let (a_i, a_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) A[i] A[j] |> unzip
        let (g_i, g_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) G[i] G[j] |> unzip
        in ((i,j),(a_i,a_j),(g_i,g_j))
    let (ijs, a_ijs, g_ijs) = map new_rows indices |> unzip3
    let (is, js) = unzip ijs
    let (a_is, a_js) = unzip a_ijs
    let (g_is, g_js) = unzip g_ijs
    -- Write all i-rows followed by all j-rows back into each matrix.
    let A = scatter A (is++js) (a_is++a_js)
    let G = scatter G (is++js) (g_is++g_js)
    in (A, G)

-- Same batch rotation as givens_par_rot, but on a single matrix AG
-- instead of two separate matrices.  Each element of 'indices' is
-- (col, i, j): coefficients come from AG[i,col] and AG[j,col], and whole
-- rows i and j of AG are replaced.
-- The scatter in this function's result is what the 'simplify' type
-- error quoted in this report points at (scatter_res).
local def givens_par_rot_merged [n][m][k] (indices:[k](i64,i64,i64))  (AG:*[n][m]f32) =
    -- Compute the replacement rows i and j for one rotation.
    let new_rows (col,i,j) =
        let a = AG[i,col]
        let b = AG[j,col]
        let (c,s,_) = givens a b
        let (ag_i, ag_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) AG[i] AG[j] |> unzip
        in ((i,j),(ag_i,ag_j))
    let (ijs, ag_ijs) = map new_rows indices |> unzip
    let (is, js) = unzip ijs
    let (ag_is, ag_js) = unzip ag_ijs
    -- Write all i-rows followed by all j-rows back in one scatter.
    in scatter AG (is++js) (ag_is++ag_js)

-- Apply one rotation, given as (col, i, j), to AG: coefficients come
-- from AG[i,col] and AG[j,col]; the result is AG with rows i and j
-- replaced by the rotated rows.
-- Not called by any other code visible in this report.
local def givens_rot (col:i64, i:i64, j:i64) (AG:*[][]f32) =
    let a = AG[i,col]
    let b = AG[j,col]
    let (c,s,_) = givens a b
    let (ag_i, ag_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) AG[i] AG[j] |> unzip
    in AG with [i] = ag_i with [j] = ag_j

-- Precompute the whole rotation schedule for an n-by-m matrix.
-- Returns (indices, sections):
--   indices  : flat array of (col, i, j) triples, in schedule order;
--   sections : one (start, end) pair per schedule step, used by the
--              callers in this file as 'indices[start:end]'.
local def all_work_indices (n:i64) (m:i64) =
    -- One block per column i: (column, current progress, upper limit).
    -- The limit is n for column 0; for other columns it is the progress
    -- value passed as prog', clamped to the range [0, n-i].
    let block i prog prog' =
        (i, prog, i64.max 0 (i64.min (n-i) (if i == 0 then n else prog')))
    -- Number of rotations a block contributes in one step.
    -- NOTE(review): reads as (b-a) >> 1 under Futhark operator
    -- precedence -- confirm.
    let size (_, a, b) = b-a >> 1
    -- k'th rotation of a block: row j counts down from n-1-a, and row i
    -- lies 'size' rows above j.
    let get (c, a, b) k =
        let j = n-1-a-k
        let i = j - size (c, a, b)
        in (c,i,j)
    -- First pass: only count how many steps are needed until every
    -- column i has progress[i] >= n-i-1.
    let (iter, _) = loop (iter, progress) = (0, replicate m 0)
        while not (all id (tabulate m (\i -> progress[i] >= n-i-1))) do
                let blockrow = map3 block
                        (iota m) (progress) (rotate (-1) progress)
                let sizes = map size blockrow
                in (iter + 1, map2 (+) progress sizes)
    -- Second pass: replay the same progression for exactly 'iter' steps,
    -- this time recording each step's blocks and its total rotation count.
    let (blocks, iter_sizes, _) =
        loop (blocks, iter_sizes, progress) = (replicate iter (replicate m (0,0,0)), replicate iter 0, replicate m 0)
            for j < iter do
                let blockrow = map3 block
                    (iota m) (progress) (rotate (-1) progress)
                let sizes = map size blockrow
                in (blocks with [j] = blockrow, iter_sizes with [j] = i64.sum sizes, map2 (+) progress sizes)
    -- Expand every recorded block into its individual (col, i, j) triples.
    let indices = expand size get (flatten blocks)
    -- Turn per-step counts into (start, end) pairs via inclusive prefix sums.
    let sections = map2 (\a b -> (b-a, b)) iter_sizes (scan (+) 0 iter_sizes)
    in (indices, sections)


-- Run a precomputed rotation schedule on A and G together, and return
-- (Q, R).  Each row of A is concatenated with the matching row of G so
-- that one call to givens_par_rot_merged per section updates both.
-- A and G are taken as unique (consumed) arguments.
local def apply_work_indices_merged [n][m]
    (indices:[](i64,i64,i64), sections:[](i64,i64))
    (A:*[n][m]f32) (G:*[n][n]f32) =
    -- Each section s selects the slice indices[s.0:s.1] for one step.
    let AG = loop AG = map2 (++) A G
        for s in sections do
            givens_par_rot_merged indices[s.0:s.1] AG
    -- Split every row back into its A part (R) and its G part.
    let (R, G) = map split AG |> unzip
    let Q = transpose G
    in (Q, R)

-- Variant of apply_work_indices_merged that takes A and G as non-unique
-- arguments and copies them before concatenating the rows.  This is the
-- function mapped over the batch by map_QR_decomp.
local def apply_work_indices_merged_copy [n][m]
    (indices:[](i64,i64,i64), sections:[](i64,i64))
    (A:[n][m]f32) (G:[n][n]f32) =
    -- Each section s selects the slice indices[s.0:s.1] for one step.
    let AG = loop AG = map2 (++) (copy A) (copy G)
        for s in sections do
            givens_par_rot_merged indices[s.0:s.1] AG
    -- Split every row back into its A part (R) and its G part.
    let (R, G) = map split AG |> unzip
    let Q = transpose G
    in (Q, R)

-- Run a precomputed rotation schedule with A and G kept as two separate
-- matrices (via givens_par_rot), and return (Q, R).
-- Not called by any entry point visible in this report.
local def apply_work_indices [n][m]
    (indices:[](i64,i64,i64), sections:[](i64,i64))
    (A:*[n][m]f32) (G:*[n][n]f32) =
    -- Each section s selects the slice indices[s.0:s.1] for one step.
    let (R, G) = loop (A, G) = (A, G)
        for s in sections do
            givens_par_rot indices[s.0:s.1] A G
    let Q = transpose G
    in (Q, R)

-- Entry point: computes the schedule for an n-by-m matrix and applies it
-- to a copy of A together with an n-by-n identity matrix.  Returns the
-- (Q, R) pair produced by apply_work_indices_merged.
-- The first error quoted in this report names this entry point
-- (entry_QR_decomp).
entry QR_decomp [n][m] (A:[n][m]f32) =
    apply_work_indices_merged (all_work_indices n m) (copy A) (identity_mat n)

-- Entry point: batched version over nr matrices.  The schedule is
-- computed once and shared; each matrix is paired with its own n-by-n
-- identity matrix.
-- According to the report, the errors appeared once this entry point
-- was added; the 'simplify' type error is reported inside
-- entry_map_QR_decomp.
entry map_QR_decomp [nr][n][m] (As:[nr][n][m]f32) =
    let indices = all_work_indices n m
    in map2 (apply_work_indices_merged_copy indices) (copy As) (replicate nr (identity_mat n))

Adding this additional code results in the first error again:

-- Sum of the elementwise products of a and b.
def dot_prod a b = f32.sum (map2 (*) a b)
-- n-by-m matrix whose entry (i, j) is a[i] * b[j].
def outer_prod [n][m] (a:[n]f32) (b:[m]f32) = map (\a -> map (\b -> a*b) b) a

-- Matrix product: entry (i, j) is the dot product of row i of A with
-- column j of B.
def mat_mult [n][m][l] (A:[n][m]f32) (B:[m][l]f32) : [n][l]f32 =
    tabulate_2d n l (\i j -> dot_prod A[i,:] B[:,j])

-- Matrix-vector product: entry i is the dot product of row i of A with b.
def matvec_mult [n][m] (A:[n][m]f32) (b:[m]f32) : [n]f32 =
    tabulate n (\i -> dot_prod A[i,:] b)


-- Entry point: repeatedly replaces A by R*Q, where (Q, R) comes from
-- QR_decomp, until 'condition' holds, then returns the diagonal of the
-- final matrix.
-- NOTE(review): this calls the entry point QR_decomp from inside another
-- entry point; the report says adding this code brings back the
-- "Duplicate definition of function entry_QR_decomp" error.
entry eigen_old [n] (A:[n][n]f32) =
    -- One step: decompose, then multiply the factors in reverse order.
    let iteration A =
        let (Q, R) = QR_decomp A
        in (mat_mult R Q)
    -- Read off the diagonal entries.
    let extraction A = tabulate n (\i -> A[i,i])
    -- Stop once the summed magnitude strictly below the diagonal is less
    -- than the summed diagonal magnitude times f32.epsilon times n.
    let condition A =
        let (diag, lower) = tabulate_2d n n (\i j ->
            let a = f32.abs A[i,j]
            let lower = i > j
            let diag = i == j
            in (f32.bool diag * a, f32.bool lower * a)) |> flatten |> unzip
        in f32.sum lower < f32.sum diag * f32.epsilon * f32.i64 n
    in iterate_until condition iteration A |> extraction

This all happened when I added map_QR_decomp.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions