Skip to content

Commit 3ded349

Browse files
committed
Gradient actually works
1 parent 2b06ab1 commit 3ded349

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

src/core.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ function get_collection end
333333
return g.collections[name]
334334
end
335335

336-
const DEBUG_EXTEND_GRAPH = false
336+
const DEBUG_EXTEND_GRAPH = true
337337

338338
function Base.convert(::Type{tensorflow.NodeDef}, proto::Vector{UInt8})
339339
b = IOBuffer()
@@ -351,14 +351,15 @@ end
351351
ph_names = Set{String}()
352352
for node_bytes in node_defs
353353
node_def = convert(tensorflow.NodeDef, node_bytes)
354+
@show node_def
354355
if isnull(get_node_by_name(graph, node_def.name))
355356
# First try to directly add this node to the graph
356-
try
357-
new_op = Operation(node_def)
358-
continue
359-
catch err
360-
DEBUG_EXTEND_GRAPH && warn(err)
361-
end
357+
# try
358+
# new_op = Operation(node_def)
359+
# continue
360+
# catch err
361+
# DEBUG_EXTEND_GRAPH && warn(err)
362+
# end
362363

363364
# If that doesn't work (for example, the node has a
364365
# back edge), then import the node instead.

src/ops/control_flow.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ Example using shape_invariants:
248248
shape_invariants=[i0.get_shape(), tensor_shape.TensorShape([None, 2])])
249249
```
250250
"""
251-
@op function while_loop(condition, body, variables; name=nothing, shape_invariants=nothing,
251+
@op function jl_while_loop(condition, body, variables; name=nothing, shape_invariants=nothing,
252252
parallel_iterations=10, back_prop=true, swap_memory=false)
253253
g = Graph()
254254
def_graph = get_def_graph()
@@ -445,7 +445,8 @@ function WhileLoopOptions(;parallel_iterations=10, back_prop=true, swap_memory=f
445445
WhileLoopOptions(parallel_iterations, back_prop, swap_memory)
446446
end
447447

448-
function c_while(condition, body, variables; name=nothing, options=WhileLoopOptions())
448+
function while_loop(condition, body, variables; name=nothing, options=WhileLoopOptions())
449+
variables = Tensor.(variables)
449450
name === nothing && (name = "while")
450451
name = String(name)
451452
graph = get_def_graph()
@@ -485,7 +486,7 @@ function create_while_context(graph, name, n_inputs; options=WhileLoopOptions())
485486
loop_exit_names=String[])
486487
context_matcher = Regex("^$(name)/")
487488
for op in get_operations(graph)
488-
@show op
489+
# @show op
489490
if ismatch(context_matcher, get_def(op).name)
490491
def = get_def(op)
491492
n_outputs = length(get_op_def(def.op).output_arg)
@@ -494,18 +495,19 @@ function create_while_context(graph, name, n_inputs; options=WhileLoopOptions())
494495
end
495496
end
496497
end
497-
push!(ctx.values_def.values, "$(name)/merge0:1")
498-
push!(ctx.values_def.values, "$(name)/switch0:1")
498+
# push!(ctx.values_def.values, "$(name)/merge0:1")
499+
# push!(ctx.values_def.values, "$(name)/switch0:1")
499500
set_field!(ctx, :pivot_for_pred_name, "$(name)/merge0:0")
500501
switch_name = "$(name)/switch0"
501502
switch_op = get_node_by_name(switch_name) |> get |> get_def
502503
# We assume the pivot tensor is the second input to the switch statement.
503-
# The first input is the result of the merge.
504+
# The first input is the result of the merge.
504505
cond_op = switch_op.input[2]
505506
set_field!(ctx, :pivot_for_body_name, "$(switch_name):0")
506507
set_field!(ctx, :pivot_name, "$(cond_op):0")
507508
for i in 1:n_inputs
508509
push!(ctx.loop_exit_names, "$(name)/exit$(i-1):0")
509510
end
511+
# dump(ctx)
510512
return ctx
511513
end

0 commit comments

Comments
 (0)