-
Notifications
You must be signed in to change notification settings - Fork 191
Closed
Description
I first got this error,
After internalisation:
Duplicate definition of function entry_QR_decomp
then I reduced the code a bit and got this
Type error after pass 'simplify':
In function entry_map_QR_decomp
When checking function body
In expression of statement
{AG_r_15103 : ({}, [nr_12256][n_12257][d<{(+) m n}>_13209]f32)}
Inside the loop body
In expression of statement
{AG_15134 : ({}, [nr_12256][n_12257][d<{(+) m n}>_13209]f32)}
in body of case {true}
In expression of statement
{AG_15135 : ({}, [nr_12256][n_12257][d<{(+) m n}>_13209]f32)}
In expression of statement
{scatter_res_15227 : ({}, [n_12257][d<{(+) m n}>_13209]f32)}
1 indices given, but type of indexee has 2 dimension(s).
The reduced code:
let repl_segm_iota [n] (szs:[n]i64) : ([]i64, []i64) =
if n == 0 then ([], []) else
let offsets = scan (+) 0 szs
let start_idx = map2 (-) offsets szs
let sz = last offsets
let repl = hist (+) 0 sz offsets (replicate n 1)
|> scan (+) 0
let segm = tabulate sz (\i -> i - start_idx[repl[i]])
in (repl, segm)
def expand [n] 'a 'b (size: a -> i64) (get: a -> i64 -> b) (as:[n]a) : []b =
let (is, js) = repl_segm_iota (map size as)
in map2 (\i j -> get as[i] j) is js
local def givens (a:f32) (b:f32) =
let sign x = let s = f32.sgn x in if s == 0 then 1 else s
let abs = f32.abs
in if b == 0 then
(sign a, 0, abs a)
else if a == 0 then
(0, -sign b, abs b)
else if abs a > abs b then
let t = b / a
let u = sign a * f32.sqrt (1+t*t)
let ui = f32.recip u
in (ui, -ui * t, a * u)
else
let t = a / b
let u = sign a * f32.sqrt (1+t*t)
let ui = f32.recip u
in (t * ui, -ui, b * u)
def identity_mat n = tabulate_2d n n (\i j -> f32.bool (i==j))
def diagonal_mat [n] (d:[n]f32) = tabulate_2d n n (\i j -> if i==j then d[i] else 0)
--row indices must be unique
local def givens_par_rot [n][m][k] (indices:[k](i64,i64,i64)) (A:*[n][m]f32) (G:*[n][n]f32) =
let new_rows (col,i,j) =
let a = A[i,col]
let b = A[j,col]
let (c,s,_) = givens a b
let (a_i, a_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) A[i] A[j] |> unzip
let (g_i, g_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) G[i] G[j] |> unzip
in ((i,j),(a_i,a_j),(g_i,g_j))
let (ijs, a_ijs, g_ijs) = map new_rows indices |> unzip3
let (is, js) = unzip ijs
let (a_is, a_js) = unzip a_ijs
let (g_is, g_js) = unzip g_ijs
let A = scatter A (is++js) (a_is++a_js)
let G = scatter G (is++js) (g_is++g_js)
in (A, G)
local def givens_par_rot_merged [n][m][k] (indices:[k](i64,i64,i64)) (AG:*[n][m]f32) =
let new_rows (col,i,j) =
let a = AG[i,col]
let b = AG[j,col]
let (c,s,_) = givens a b
let (ag_i, ag_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) AG[i] AG[j] |> unzip
in ((i,j),(ag_i,ag_j))
let (ijs, ag_ijs) = map new_rows indices |> unzip
let (is, js) = unzip ijs
let (ag_is, ag_js) = unzip ag_ijs
in scatter AG (is++js) (ag_is++ag_js)
local def givens_rot (col:i64, i:i64, j:i64) (AG:*[][]f32) =
let a = AG[i,col]
let b = AG[j,col]
let (c,s,_) = givens a b
let (ag_i, ag_j) = map2 (\a b -> (c*a - s*b, s*a + c*b)) AG[i] AG[j] |> unzip
in AG with [i] = ag_i with [j] = ag_j
local def all_work_indices (n:i64) (m:i64) =
let block i prog prog' =
(i, prog, i64.max 0 (i64.min (n-i) (if i == 0 then n else prog')))
let size (_, a, b) = b-a >> 1
let get (c, a, b) k =
let j = n-1-a-k
let i = j - size (c, a, b)
in (c,i,j)
let (iter, _) = loop (iter, progress) = (0, replicate m 0)
while not (all id (tabulate m (\i -> progress[i] >= n-i-1))) do
let blockrow = map3 block
(iota m) (progress) (rotate (-1) progress)
let sizes = map size blockrow
in (iter + 1, map2 (+) progress sizes)
let (blocks, iter_sizes, _) =
loop (blocks, iter_sizes, progress) = (replicate iter (replicate m (0,0,0)), replicate iter 0, replicate m 0)
for j < iter do
let blockrow = map3 block
(iota m) (progress) (rotate (-1) progress)
let sizes = map size blockrow
in (blocks with [j] = blockrow, iter_sizes with [j] = i64.sum sizes, map2 (+) progress sizes)
let indices = expand size get (flatten blocks)
let sections = map2 (\a b -> (b-a, b)) iter_sizes (scan (+) 0 iter_sizes)
in (indices, sections)
local def apply_work_indices_merged [n][m]
(indices:[](i64,i64,i64), sections:[](i64,i64))
(A:*[n][m]f32) (G:*[n][n]f32) =
let AG = loop AG = map2 (++) A G
for s in sections do
givens_par_rot_merged indices[s.0:s.1] AG
let (R, G) = map split AG |> unzip
let Q = transpose G
in (Q, R)
local def apply_work_indices_merged_copy [n][m]
(indices:[](i64,i64,i64), sections:[](i64,i64))
(A:[n][m]f32) (G:[n][n]f32) =
let AG = loop AG = map2 (++) (copy A) (copy G)
for s in sections do
givens_par_rot_merged indices[s.0:s.1] AG
let (R, G) = map split AG |> unzip
let Q = transpose G
in (Q, R)
local def apply_work_indices [n][m]
(indices:[](i64,i64,i64), sections:[](i64,i64))
(A:*[n][m]f32) (G:*[n][n]f32) =
let (R, G) = loop (A, G) = (A, G)
for s in sections do
givens_par_rot indices[s.0:s.1] A G
let Q = transpose G
in (Q, R)
entry QR_decomp [n][m] (A:[n][m]f32) =
apply_work_indices_merged (all_work_indices n m) (copy A) (identity_mat n)
entry map_QR_decomp [nr][n][m] (As:[nr][n][m]f32) =
let indices = all_work_indices n m
in map2 (apply_work_indices_merged_copy indices) (copy As) (replicate nr (identity_mat n))
adding this additional code results in the first error again.
def dot_prod a b = f32.sum (map2 (*) a b)
def outer_prod [n][m] (a:[n]f32) (b:[m]f32) = map (\a -> map (\b -> a*b) b) a
def mat_mult [n][m][l] (A:[n][m]f32) (B:[m][l]f32) : [n][l]f32 =
tabulate_2d n l (\i j -> dot_prod A[i,:] B[:,j])
def matvec_mult [n][m] (A:[n][m]f32) (b:[m]f32) : [n]f32 =
tabulate n (\i -> dot_prod A[i,:] b)
entry eigen_old [n] (A:[n][n]f32) =
let iteration A =
let (Q, R) = QR_decomp A
in (mat_mult R Q)
let extraction A = tabulate n (\i -> A[i,i])
let condition A =
let (diag, lower) = tabulate_2d n n (\i j ->
let a = f32.abs A[i,j]
let lower = i > j
let diag = i == j
in (f32.bool diag * a, f32.bool lower * a)) |> flatten |> unzip
in f32.sum lower < f32.sum diag * f32.epsilon * f32.i64 n
in iterate_until condition iteration A |> extraction
This all happened when I added map_QR_decomp .