
WIP: Use C API while loop functions #338


Closed · wants to merge 15 commits
2 changes: 0 additions & 2 deletions .gitignore
@@ -1,12 +1,10 @@
deps/usr
deps/downloads
deps/miniconda
src/scratch.jl
tfdocs
docs/build
*.DS_Store
*.*~

*.swp
*.jld
scratch
1 change: 1 addition & 0 deletions src/TensorFlow.jl
@@ -193,5 +193,6 @@ include("summary.jl")
include("deprecated.jl")
include("show.jl")
include("generate_ops.jl")
include("debug.jl")

end
75 changes: 49 additions & 26 deletions src/core.jl
@@ -240,6 +240,11 @@ TensorShape(t::TensorShape) = copy(t)

function get_shape end

+struct ParentGraph
+    graph::Any # Really, Graph
+    prefix::String
+end
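Note: `graph` is typed `Any`, presumably because `Graph` is only defined further down this file. A hypothetical instance for a loop-body graph nested under the default graph (the prefix is invented):

parent = ParentGraph(get_def_graph(), "while/")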

"""
A TensorFlow computation graph
"""
@@ -249,21 +254,42 @@ mutable struct Graph
    shapes::Dict{String, TensorShape}
    name_idx::Dict{String, Int}
    op_context::OperationContext
+    input_override::Dict
+    parent::Union{Void, ParentGraph}

-    function Graph()
-        ptr = @tfcall(:TF_NewGraph, Ptr{Void}, ())
+    function Graph(ptr::Ptr)
        collections = Dict{Symbol, Any}()
        collections[:Variables] = []
        collections[:TrainableVariables] = []
        collections[:Summaries] = []
        collections[:QueueRunners] = []
        collections[:while_context] = []
-        self = new(ptr, collections, Dict{String, TensorShape}(), Dict{String, Int}(), OperationContext(Vector{Operation}[], String[], tensorflow.WhileContextDef[], Device[], Ref(false)))
-        finalizer(self, self->begin
-            @tfcall(:TF_DeleteGraph, Void, (Ptr{Void},), self.ptr)
-        end)
+        self = new(ptr, collections, Dict{String, TensorShape}(), Dict{String, Int}(), OperationContext(Vector{Operation}[], String[], tensorflow.WhileContextDef[], Device[], Ref(false)), Dict(), nothing)
        self
    end

end

+function add_input_override(g::Graph, original, override)
+    if override === nothing
+        delete!(g.input_override, original)
+    else
+        g.input_override[original] = override
+    end
+end
+
+function clear_input_overrides(g::Graph)
+    empty!(g.input_override)
+end
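A hypothetical usage sketch of the override mechanism (the tensor names are invented): while constructing a while-loop body, ops that would capture an outer tensor can be rewired to a loop parameter, and the overrides dropped once the body is complete.

g = get_def_graph()
x = placeholder(Float32)          # hypothetical outer tensor
loop_x = placeholder(Float32)     # hypothetical loop-body parameter
add_input_override(g, x, loop_x)  # ops built from here consume loop_x wherever x appears
y = x + 1                         # actually wired to loop_x by add_input (see below)
clear_input_overrides(g)          # or add_input_override(g, x, nothing) to drop one override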


+function Graph()
+    ptr = @tfcall(:TF_NewGraph, Ptr{Void}, ())
+    self = Graph(ptr)
+    finalizer(self, self->begin
+        @tfcall(:TF_DeleteGraph, Void, (Ptr{Void},), self.ptr)
+    end)
+    self
+end
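Note the split: only the zero-argument constructor registers the `TF_DeleteGraph` finalizer, so `Graph(ptr)` can wrap a pointer owned elsewhere (for instance, a cond/body graph handed out by the C API's while-loop builder) without double-freeing it. A minimal sketch, assuming the caller really does own the pointer:

raw = @tfcall(:TF_NewGraph, Ptr{Void}, ())
g = Graph(raw)  # no finalizer; the caller must eventually arrange TF_DeleteGraph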

function Base.show(io::IO, g::Graph)
@@ -327,7 +353,7 @@ function get_collection end
    return g.collections[name]
end

-const DEBUG_EXTEND_GRAPH = false
+const DEBUG_EXTEND_GRAPH = true

function Base.convert(::Type{tensorflow.NodeDef}, proto::Vector{UInt8})
    b = IOBuffer()
@@ -347,12 +373,12 @@ end
        node_def = convert(tensorflow.NodeDef, node_bytes)
        if isnull(get_node_by_name(graph, node_def.name))
            # First try to directly add this node to the graph
-            try
-                new_op = Operation(node_def)
-                continue
-            catch err
-                DEBUG_EXTEND_GRAPH && warn(err)
-            end
+            # try
+            #     new_op = Operation(node_def)
+            #     continue
+            # catch err
+            #     DEBUG_EXTEND_GRAPH && warn(err)
+            # end

            # If that doesn't work (for example, the node has a
            # back edge), then import the node instead.
@@ -397,8 +423,8 @@ end
                input_name = new_name
            end
            node_def.input[i] = input_name
-
-            import_options.input_mapping[(new_name, source_port)] = Tensor(get(existing_node), dest_port)
+            # The 'Any' here is needed to suppress trying to automatically determine the output type of the tensor, which for some reason is crashing.
+            import_options.input_mapping[(new_name, source_port)] = Tensor{Any}(get(existing_node), dest_port)
            new_ph = tensorflow.NodeDef()
            set_field!(new_ph, :name, new_name)
            if is_control
@@ -905,9 +931,11 @@ Base.hash(op::Operation, h::UInt) = hash(Operation, hash(op.ptr, h))

struct Port
    node_ptr::Ptr{Void}
-    index::Int
+    index::Cint
end

+const TF_Output = Port
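Making `index` a `Cint` keeps `Port` layout-compatible with the C API's `TF_Output` struct, so it can be passed by value through `@tfcall`; the alias records that correspondence. For reference:

# C side (tensorflow/c/c_api.h):
#   typedef struct TF_Output { TF_Operation* oper; int index; } TF_Output;
# Port mirrors that layout: Ptr{Void} <-> TF_Operation*, Cint <-> int.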

function get_num_inputs(op::Operation)
    @tfcall(:TF_OperationNumInputs, Cint, (Ptr{Void},), op.ptr) |> Int
end
@@ -1106,13 +1134,6 @@ function Operation(ptr::Ptr)
    return self
end

-struct NodeNameNotFound <: Exception
-    name::String
-end
-
-function Base.show(io::IO, err::NodeNameNotFound)
-    print(io, "Node $(err.name) not found in graph")
-end

get_graph(n::AbstractOperation) = Operation(n).graph

@@ -1317,7 +1338,9 @@ Port(port::Port) = port
Tensor(p::Port) = Tensor(Operation(p.node_ptr), p.index+1)

function add_input(desc::NodeDescription, input::Union{Tensor, Operation})
-    @tfcall(:TF_AddInput, Void, (Ptr{Void}, Port), desc.ptr, Port(input))
+    graph = get(get_graph(desc))
+    remapped_input = get(graph.input_override, input, input)
+    @tfcall(:TF_AddInput, Void, (Ptr{Void}, Port), desc.ptr, Port(remapped_input))
end
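The `get(dict, key, default)` form makes the remapping a no-op when no override is registered. A hypothetical trace, with `a`, `b`, and `c` standing in for tensors:

add_input_override(g, a, b)
get(g.input_override, a, a)  # === b, so a's consumers are rewired to b
get(g.input_override, c, c)  # === c, falls through unchanged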

function add_input(desc::NodeDescription, inputs::Vector)
@@ -1531,12 +1554,12 @@ end

Returns the tensor with name `name` (in name:port format) in the given graph.

-Throws a `NodeNameNotFound` exception if there is no such tensor.
+Returns `nothing` if the tensor is not found.
"""
@with_def_graph function get_tensor_by_name(graph::Graph, full_name)
    name, port = parse_port_name(full_name)
    maybe_node = get_node_by_name(graph, name)
-    isnull(maybe_node) && throw(NodeNameNotFound(full_name))
+    isnull(maybe_node) && return nothing
    node = get(maybe_node)
    return Tensor(node, port)
end
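Callers now branch on `nothing` rather than catching an exception; a minimal sketch (the tensor name is invented):

t = get_tensor_by_name("while/Merge:0")
t === nothing && error("no such tensor in the default graph")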
8 changes: 8 additions & 0 deletions src/debug.jl
@@ -0,0 +1,8 @@
+@with_def_graph function show_op_names(g::Graph)
+    for (i, node) in enumerate(get_def(g).node)
+        println("$(i): $(node.name)")
+        for (j, input) in enumerate(node.input)
+            println(" $(j): $(input)")
+        end
+    end
+end
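Since it is wrapped in `@with_def_graph`, this can be called as `show_op_names()` to dump the default graph. Hypothetical output for a two-node graph (names invented):

julia> show_op_names()
1: x
2: add
 1: x
 2: one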