Skip to content

Commit 3f8272f

Browse files
karllessardrnett
andauthored
ConcreteFunction fix and performance improvements (#364)
Signed-off-by: Ryan Nett <[email protected]> Co-authored-by: Ryan Nett <[email protected]>
1 parent 0aaba03 commit 3f8272f

File tree

2 files changed

+75
-53
lines changed

2 files changed

+75
-53
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;
2020

2121
import java.util.ArrayList;
22-
import java.util.Arrays;
2322
import java.util.Collection;
2423
import java.util.Collections;
2524
import java.util.HashSet;
25+
import java.util.Iterator;
2626
import java.util.LinkedHashMap;
2727
import java.util.List;
2828
import java.util.Map;
@@ -66,7 +66,7 @@
6666
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
6767
* }</pre>
6868
*/
69-
public class ConcreteFunction implements AutoCloseable, TensorFunction {
69+
public final class ConcreteFunction implements AutoCloseable, TensorFunction {
7070

7171
/**
7272
* Creates a function by building a new graph.
@@ -220,11 +220,11 @@ public String toString() {
220220
public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) {
221221
List<Operand<?>> inputList = new ArrayList<>(signature.inputNames().size());
222222

223-
for (String inputName : signature().inputNames()) {
223+
for (String inputName : signature.inputNames()) {
224224
if (!arguments.containsKey(inputName)) {
225225
throw new IllegalArgumentException(
226226
"Function "
227-
+ signature().methodName()
227+
+ signature.methodName()
228228
+ " has parameter \""
229229
+ inputName
230230
+ "\", but no argument was passed for it.");
@@ -241,30 +241,30 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
241241
}
242242

243243
List<Output<?>> outputList =
244-
PartitionedCall.create(
245-
scope,
246-
inputList,
247-
Arrays.stream(outputDtypes)
248-
.map(x -> TensorTypeRegistry.find(x).type())
249-
.collect(Collectors.toList()),
250-
this)
251-
.output();
252-
253-
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());
254-
255-
List<String> outputNames = new ArrayList<>(signature().outputNames());
256-
for (int i = 0; i < outputNames.size(); i++) {
257-
String outputName = outputNames.get(i);
258-
259-
if (i > outputList.size()) {
260-
throw new IllegalStateException(
261-
"Somehow, not all required outputs were returned from the function");
262-
}
244+
PartitionedCall.create(scope, inputList, outputTypes, this).output();
263245

246+
if (signature.outputNames().size() == 0) {
247+
return Collections.emptyMap();
248+
}
249+
if (signature.outputNames().size() == 1) {
250+
return Collections.singletonMap(signature.outputNames().iterator().next(), outputList.get(0));
251+
}
252+
if (outputList.size() < signature.outputNames().size()) {
253+
throw new IllegalStateException(
254+
"Somehow, not all required outputs were returned from the function"
255+
+ "(expected: "
256+
+ signature.outputNames().size()
257+
+ ", returned: "
258+
+ outputList.size()
259+
+ ")");
260+
}
261+
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature.outputNames().size());
262+
Iterator<String> outputNames = signature.outputNames().iterator();
263+
for (int i = 0; outputNames.hasNext(); i++) {
264+
String outputName = outputNames.next();
264265
Operand<?> output = outputList.get(i);
265266
namedOutputs.put(outputName, output);
266267
}
267-
268268
return Collections.unmodifiableMap(namedOutputs);
269269
}
270270

@@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
291291
}
292292
String outputName = signatureDef.getOutputsMap().keySet().iterator().next();
293293

294-
Map<String, Operand<?>> inputMap = new LinkedHashMap<>();
295-
inputMap.put(inputName, argument);
296-
297-
return call(scope, inputMap).get(outputName);
294+
return call(scope, Collections.singletonMap(inputName, argument)).get(outputName);
298295
}
299296

300297
@Override
@@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
395392
private final NativeFunction nativeFunction;
396393
private final PointerScope scope;
397394
private final Set<TF_Function> dependencies;
398-
private final DataType[] inputDtypes;
399-
private final DataType[] outputDtypes;
395+
private final List<Class<? extends TType>> outputTypes;
400396

401397
/** All native functions should have deallocators registered */
402398
private ConcreteFunction(
@@ -405,7 +401,7 @@ private ConcreteFunction(
405401
this.nativeFunction = nativeFunction;
406402
this.dependencies = Collections.unmodifiableSet(dependencies);
407403

408-
if (this.signature.getInputs().size()
404+
if (signature.getInputs().size()
409405
!= nativeFunction.getFunctionDef().getSignature().getInputArgCount()) {
410406
throw new IllegalArgumentException(
411407
"Signature must have the same number of inputs as the native function. Expected "
@@ -414,7 +410,7 @@ private ConcreteFunction(
414410
+ this.signature.getInputs().size());
415411
}
416412

417-
if (this.signature.getOutputs().size()
413+
if (signature.getOutputs().size()
418414
!= nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) {
419415
throw new IllegalArgumentException(
420416
"New signature must have the same number of outputs as the native function. Expected "
@@ -423,10 +419,8 @@ private ConcreteFunction(
423419
+ this.signature.getOutputs().size());
424420
}
425421

426-
inputDtypes =
427-
this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
428-
429-
List<DataType> inputs = Arrays.asList(inputDtypes);
422+
List<DataType> inputs =
423+
signature.getInputs().values().stream().map(x -> x.dataType).collect(Collectors.toList());
430424
List<DataType> nativeInputs =
431425
nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
432426
.map(ArgDef::getType)
@@ -440,10 +434,8 @@ private ConcreteFunction(
440434
+ inputs);
441435
}
442436

443-
outputDtypes =
444-
signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
445-
446-
List<DataType> outputs = Arrays.asList(outputDtypes);
437+
List<DataType> outputs =
438+
signature.getOutputs().values().stream().map(x -> x.dataType).collect(Collectors.toList());
447439
List<DataType> nativeOutputs =
448440
nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
449441
.map(ArgDef::getType)
@@ -457,6 +449,9 @@ private ConcreteFunction(
457449
+ outputs);
458450
}
459451

452+
outputTypes =
453+
outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList());
454+
460455
try (PointerScope scope = new PointerScope()) {
461456
this.scope = scope;
462457
scope.extend();
@@ -469,6 +464,8 @@ private ConcreteFunction(
469464
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
470465
* how to enable XLA JIT is extremely non-obvious.
471466
*
467+
* <p>See https://github.com/tensorflow/java/issues/347
468+
*
472469
* <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
473470
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
474471
*/

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.
22
3-
Licensed under the Apache License, Version 2.0 (the "License");
4-
you may not use this file except in compliance with the License.
5-
You may obtain a copy of the License at
6-
7-
http://www.apache.org/licenses/LICENSE-2.0
8-
9-
Unless required by applicable law or agreed to in writing, software
10-
distributed under the License is distributed on an "AS IS" BASIS,
11-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
See the License for the specific language governing permissions and
13-
limitations under the License.
14-
=======================================================================
15-
*/
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
1616
package org.tensorflow;
1717

1818
import static org.junit.jupiter.api.Assertions.assertEquals;
1919
import static org.junit.jupiter.api.Assertions.assertNotNull;
2020

2121
import java.util.Arrays;
22+
import java.util.HashMap;
23+
import java.util.Map;
2224
import org.junit.jupiter.api.Test;
2325
import org.tensorflow.op.Ops;
2426
import org.tensorflow.op.core.Init;
@@ -27,6 +29,7 @@
2729
import org.tensorflow.op.math.Sub;
2830
import org.tensorflow.proto.framework.DataType;
2931
import org.tensorflow.types.TFloat32;
32+
import org.tensorflow.types.TInt32;
3033

3134
public class ConcreteFunctionTest {
3235

@@ -144,6 +147,28 @@ public void testNestedFunctionGraph() {
144147
}
145148
}
146149

150+
@Test
151+
public void testFunctionWithTwoOutputs() {
152+
ConcreteFunction cf =
153+
ConcreteFunction.create(
154+
tf -> {
155+
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
156+
Operand<TInt32> dblX = tf.math.add(x, x);
157+
Operand<TInt32> tripX = tf.math.add(x, dblX);
158+
return Signature.builder()
159+
.input("x", x)
160+
.output("dbl", dblX)
161+
.output("trpl", tripX)
162+
.build();
163+
});
164+
165+
Map<String, Tensor> inputs = new HashMap<>();
166+
inputs.put("x", TInt32.scalarOf(2));
167+
Map<String, Tensor> outputs = cf.call(inputs);
168+
assertEquals(4, ((TInt32) outputs.get("dbl")).getInt());
169+
assertEquals(6, ((TInt32) outputs.get("trpl")).getInt());
170+
}
171+
147172
private static Signature square(Ops tf) {
148173
Placeholder<TFloat32> input = tf.placeholder(TFloat32.class);
149174
Operand<TFloat32> output = tf.math.square(input);

0 commit comments

Comments
 (0)