@@ -472,10 +472,15 @@ end
472
472
473
473
474
474
"""
475
- function internalize (outer_graph, body_func, cond_func, var_list, input_overrides= Int[])
475
+ function internalize (outer_graph, body_func, cond_func, var_list, loop_name, stack_depth= 1 , input_overrides= Int[])
476
+ if stack_depth > 2
477
+ error (" internalize is recursing too much" )
478
+ return nothing
479
+ end
476
480
external_name = nothing
477
481
try
478
482
body_graph = Graph ()
483
+ body_graph. parent = ParentGraph (outer_graph, loop_name)
479
484
as_default (body_graph) do
480
485
inner_vars = []
481
486
for var in var_list
@@ -488,6 +493,7 @@ function internalize(outer_graph, body_func, cond_func, var_list, input_override
488
493
body_func (inner_vars... )
489
494
end
490
495
cond_graph = Graph ()
496
+ cond_graph. parent = ParentGraph (outer_graph, loop_name)
491
497
as_default (cond_graph) do
492
498
inner_vars = []
493
499
for var in var_list
@@ -508,6 +514,7 @@ function internalize(outer_graph, body_func, cond_func, var_list, input_override
508
514
end
509
515
if external_name != = nothing
510
516
external_tensor = get_tensor_by_name (outer_graph, external_name)
517
+ @show external_tensor
511
518
new_var_list = copy (var_list)
512
519
push! (new_var_list, external_tensor)
513
520
push! (input_overrides, length (new_var_list))
@@ -519,13 +526,14 @@ function internalize(outer_graph, body_func, cond_func, var_list, input_override
519
526
function new_cond_func (vars... )
520
527
cond_func ((vars[1 : end - 1 ]). .. )
521
528
end
522
- internalize (outer_graph, new_body_func, new_cond_func, new_var_list, input_overrides)
529
+ internalize (outer_graph, new_body_func, new_cond_func, new_var_list, loop_name, stack_depth + 1 , input_overrides)
523
530
else
524
531
return WhileGraph (body_func, cond_func, var_list, input_overrides)
525
532
end
526
533
end
527
534
528
535
function add_overrides (overrides, variables, inputs)
536
+
529
537
for override in overrides# internalized_graph.input_overrides
530
538
add_input_override (get_def_graph (), variables[override], inputs[override])
531
539
end
@@ -538,31 +546,39 @@ function while_loop(condition, body, variables; name=nothing, options=WhileLoopO
538
546
# TODO : fix underlying GC issue
539
547
n_variable_original = length (variables)
540
548
variables = Tensor .(variables)
541
- internalized_graph = internalize (get_def_graph (), body, condition, variables)
549
+ name === nothing && (name = get_name (" while" ))
550
+ name = String (name)
551
+ internalized_graph = internalize (get_def_graph (), body, condition, variables, name)
542
552
body = internalized_graph. body_func
543
553
condition = internalized_graph. cond_func
544
554
variables = internalized_graph. vars
555
+ identity_variables = identity .(variables)
545
556
gc_enable (false )
546
557
547
- name === nothing && (name = get_name (" while" ))
548
- name = String (name)
558
+
549
559
graph = get_def_graph ()
550
- params = new_while (graph, variables )
560
+ params = new_while (graph, identity_variables )
551
561
params. name = pointer (name)
552
562
n_inputs = length (variables)
553
563
cond_inputs_c = unsafe_wrap (Array, params. cond_inputs, n_inputs)
554
564
cond_inputs = Tensor .(cond_inputs_c)
555
565
local cond_output
556
- as_default (Graph (params. cond_graph)) do
566
+ cond_graph = Graph (params. cond_graph)
567
+ cond_graph. parent = ParentGraph (graph, name)
568
+ as_default (cond_graph) do
557
569
add_overrides (internalized_graph. input_overrides, variables, cond_inputs)
558
570
cond_output = condition (cond_inputs... )
559
571
end
560
572
params. cond_output = TF_Output (cond_output)
561
573
body_inputs_c = unsafe_wrap (Array, params. body_inputs, n_inputs)
562
574
body_inputs = Tensor .(body_inputs_c)
563
575
local body_outputs
564
- as_default (Graph (params. body_graph)) do
576
+ body_graph = Graph (params. body_graph)
577
+ body_graph. parent = ParentGraph (graph, name)
578
+
579
+ as_default (body_graph) do
565
580
add_overrides (internalized_graph. input_overrides, variables, body_inputs)
581
+
566
582
body_outputs = body (body_inputs... )
567
583
end
568
584
body_outputs_c = unsafe_wrap (Array, params. body_outputs, n_inputs)
0 commit comments