@@ -248,7 +248,7 @@ Example using shape_invariants:
248
248
shape_invariants=[i0.get_shape(), tensor_shape.TensorShape([None, 2])])
249
249
```
250
250
"""
251
- @op function while_loop (condition, body, variables; name= nothing , shape_invariants= nothing ,
251
+ @op function jl_while_loop (condition, body, variables; name= nothing , shape_invariants= nothing ,
252
252
parallel_iterations= 10 , back_prop= true , swap_memory= false )
253
253
g = Graph ()
254
254
def_graph = get_def_graph ()
@@ -445,7 +445,8 @@ function WhileLoopOptions(;parallel_iterations=10, back_prop=true, swap_memory=f
445
445
WhileLoopOptions (parallel_iterations, back_prop, swap_memory)
446
446
end
447
447
448
- function c_while (condition, body, variables; name= nothing , options= WhileLoopOptions ())
448
+ function while_loop (condition, body, variables; name= nothing , options= WhileLoopOptions ())
449
+ variables = Tensor .(variables)
449
450
name === nothing && (name = " while" )
450
451
name = String (name)
451
452
graph = get_def_graph ()
@@ -485,7 +486,7 @@ function create_while_context(graph, name, n_inputs; options=WhileLoopOptions())
485
486
loop_exit_names= String[])
486
487
context_matcher = Regex (" ^$(name) /" )
487
488
for op in get_operations (graph)
488
- @show op
489
+ # @show op
489
490
if ismatch (context_matcher, get_def (op). name)
490
491
def = get_def (op)
491
492
n_outputs = length (get_op_def (def. op). output_arg)
@@ -494,18 +495,19 @@ function create_while_context(graph, name, n_inputs; options=WhileLoopOptions())
494
495
end
495
496
end
496
497
end
497
- push! (ctx. values_def. values, " $(name) /merge0:1" )
498
- push! (ctx. values_def. values, " $(name) /switch0:1" )
498
+ # push!(ctx.values_def.values, "$(name)/merge0:1")
499
+ # push!(ctx.values_def.values, "$(name)/switch0:1")
499
500
set_field! (ctx, :pivot_for_pred_name , " $(name) /merge0:0" )
500
501
switch_name = " $(name) /switch0"
501
502
switch_op = get_node_by_name (switch_name) |> get |> get_def
502
503
# We assume the pivot tensor is the second input to the switch statement.
503
- # The first input is the result of the merge.
504
+ # The first input is the result of the merge.
504
505
cond_op = switch_op. input[2 ]
505
506
set_field! (ctx, :pivot_for_body_name , " $(switch_name) :0" )
506
507
set_field! (ctx, :pivot_name , " $(cond_op) :0" )
507
508
for i in 1 : n_inputs
508
509
push! (ctx. loop_exit_names, " $(name) /exit$(i- 1 ) :0" )
509
510
end
511
+ # dump(ctx)
510
512
return ctx
511
513
end
0 commit comments