Skip to content

Commit 5ca4427

Browse files
committed
Maybe have variables working?
1 parent 82bfa05 commit 5ca4427

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

src/core.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ TensorShape(t::TensorShape) = copy(t)
240240

241241
function get_shape end
242242

243+
struct ParentGraph
244+
graph::Any # Really, Graph
245+
prefix::String
246+
end
247+
243248
"""
244249
A TensorFlow computation graph
245250
"""
@@ -250,6 +255,7 @@ mutable struct Graph
250255
name_idx::Dict{String, Int}
251256
op_context::OperationContext
252257
input_override::Dict
258+
parent::Union{Void, ParentGraph}
253259

254260
function Graph(ptr::Ptr)
255261
collections = Dict{Symbol, Any}()
@@ -258,7 +264,7 @@ mutable struct Graph
258264
collections[:Summaries] = []
259265
collections[:QueueRunners] = []
260266
collections[:while_context] = []
261-
self = new(ptr, collections, Dict{String, TensorShape}(), Dict{String, Int}(), OperationContext(Vector{Operation}[], String[], tensorflow.WhileContextDef[], Device[], Ref(false)), Dict())
267+
self = new(ptr, collections, Dict{String, TensorShape}(), Dict{String, Int}(), OperationContext(Vector{Operation}[], String[], tensorflow.WhileContextDef[], Device[], Ref(false)), Dict(), nothing)
262268
self
263269
end
264270

src/ops/control_flow.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,15 @@ end
472472
473473
474474
"""
475-
function internalize(outer_graph, body_func, cond_func, var_list, input_overrides=Int[])
475+
function internalize(outer_graph, body_func, cond_func, var_list, loop_name, stack_depth=1, input_overrides=Int[])
476+
if stack_depth > 2
477+
error("internalize is recursing too much")
478+
return nothing
479+
end
476480
external_name = nothing
477481
try
478482
body_graph = Graph()
483+
body_graph.parent = ParentGraph(outer_graph, loop_name)
479484
as_default(body_graph) do
480485
inner_vars = []
481486
for var in var_list
@@ -488,6 +493,7 @@ function internalize(outer_graph, body_func, cond_func, var_list, input_override
488493
body_func(inner_vars...)
489494
end
490495
cond_graph = Graph()
496+
cond_graph.parent = ParentGraph(outer_graph, loop_name)
491497
as_default(cond_graph) do
492498
inner_vars = []
493499
for var in var_list
@@ -508,6 +514,7 @@ function internalize(outer_graph, body_func, cond_func, var_list, input_override
508514
end
509515
if external_name !== nothing
510516
external_tensor = get_tensor_by_name(outer_graph, external_name)
517+
@show external_tensor
511518
new_var_list = copy(var_list)
512519
push!(new_var_list, external_tensor)
513520
push!(input_overrides, length(new_var_list))
@@ -519,13 +526,14 @@ function internalize(outer_graph, body_func, cond_func, var_list, input_override
519526
function new_cond_func(vars...)
520527
cond_func((vars[1:end-1])...)
521528
end
522-
internalize(outer_graph, new_body_func, new_cond_func, new_var_list, input_overrides)
529+
internalize(outer_graph, new_body_func, new_cond_func, new_var_list, loop_name, stack_depth+1, input_overrides)
523530
else
524531
return WhileGraph(body_func, cond_func, var_list, input_overrides)
525532
end
526533
end
527534

528535
function add_overrides(overrides, variables, inputs)
536+
529537
for override in overrides#internalized_graph.input_overrides
530538
add_input_override(get_def_graph(), variables[override], inputs[override])
531539
end
@@ -538,31 +546,39 @@ function while_loop(condition, body, variables; name=nothing, options=WhileLoopO
538546
# TODO: fix underlying GC issue
539547
n_variable_original = length(variables)
540548
variables = Tensor.(variables)
541-
internalized_graph = internalize(get_def_graph(), body, condition, variables)
549+
name === nothing && (name = get_name("while"))
550+
name = String(name)
551+
internalized_graph = internalize(get_def_graph(), body, condition, variables, name)
542552
body = internalized_graph.body_func
543553
condition = internalized_graph.cond_func
544554
variables = internalized_graph.vars
555+
identity_variables = identity.(variables)
545556
gc_enable(false)
546557

547-
name === nothing && (name = get_name("while"))
548-
name = String(name)
558+
549559
graph = get_def_graph()
550-
params = new_while(graph, variables)
560+
params = new_while(graph, identity_variables)
551561
params.name = pointer(name)
552562
n_inputs = length(variables)
553563
cond_inputs_c = unsafe_wrap(Array, params.cond_inputs, n_inputs)
554564
cond_inputs = Tensor.(cond_inputs_c)
555565
local cond_output
556-
as_default(Graph(params.cond_graph)) do
566+
cond_graph = Graph(params.cond_graph)
567+
cond_graph.parent = ParentGraph(graph, name)
568+
as_default(cond_graph) do
557569
add_overrides(internalized_graph.input_overrides, variables, cond_inputs)
558570
cond_output = condition(cond_inputs...)
559571
end
560572
params.cond_output = TF_Output(cond_output)
561573
body_inputs_c = unsafe_wrap(Array, params.body_inputs, n_inputs)
562574
body_inputs = Tensor.(body_inputs_c)
563575
local body_outputs
564-
as_default(Graph(params.body_graph)) do
576+
body_graph = Graph(params.body_graph)
577+
body_graph.parent = ParentGraph(graph, name)
578+
579+
as_default(body_graph) do
565580
add_overrides(internalized_graph.input_overrides, variables, body_inputs)
581+
566582
body_outputs = body(body_inputs...)
567583
end
568584
body_outputs_c = unsafe_wrap(Array, params.body_outputs, n_inputs)

src/variable.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,27 @@ function Variable(initial_value; name="", trainable=true, literal_name=false)
6767
if !literal_name
6868
name = tf.get_name(name)
6969
end
70-
self.var_node = tf.Ops.variable_v2(name=name, dtype=eltype(initial_value), shape=tf.TensorShape([size(initial_value)...]))
71-
72-
self.assign_node = tf.Ops.assign(tf.Tensor(self.var_node), initial_value, name="$name/Assign")
73-
tf.add_to_collection(:Variables, self)
74-
if trainable
75-
tf.add_to_collection(:TrainableVariables, self)
70+
graph = tf.get_def_graph()
71+
if graph.parent === nothing
72+
self.var_node = tf.Ops.variable_v2(name=name, dtype=eltype(initial_value), shape=tf.TensorShape([size(initial_value)...]))
73+
74+
self.assign_node = tf.Ops.assign(tf.Tensor(self.var_node), initial_value, name="$name/Assign")
75+
tf.add_to_collection(:Variables, self)
76+
if trainable
77+
tf.add_to_collection(:TrainableVariables, self)
78+
end
79+
return self
80+
else
81+
parent_graph = graph.parent.graph
82+
base_name = name[(length(graph.parent.prefix)+2):end] # maybe use regex instead
83+
parent_var = tf.get_tensor_by_name(parent_graph, base_name)
84+
if parent_var === nothing
85+
tf.as_default(parent_graph) do
86+
parent_var = Variable(initial_value; name=base_name, trainable=trainable, literal_name=literal_name)
87+
end
88+
end
89+
return parent_var
7690
end
77-
return self
7891
end
7992

8093
"""

0 commit comments

Comments
 (0)