19
19
import static org .tensorflow .internal .c_api .global .tensorflow .TF_GraphToFunction ;
20
20
21
21
import java .util .ArrayList ;
22
- import java .util .Arrays ;
23
22
import java .util .Collection ;
24
23
import java .util .Collections ;
25
24
import java .util .HashSet ;
25
+ import java .util .Iterator ;
26
26
import java .util .LinkedHashMap ;
27
27
import java .util .List ;
28
28
import java .util .Map ;
66
66
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
67
67
* }</pre>
68
68
*/
69
- public class ConcreteFunction implements AutoCloseable , TensorFunction {
69
+ public final class ConcreteFunction implements AutoCloseable , TensorFunction {
70
70
71
71
/**
72
72
* Creates a function by building a new graph.
@@ -220,11 +220,11 @@ public String toString() {
220
220
public Map <String , Operand <?>> call (Scope scope , Map <String , Operand <?>> arguments ) {
221
221
List <Operand <?>> inputList = new ArrayList <>(signature .inputNames ().size ());
222
222
223
- for (String inputName : signature () .inputNames ()) {
223
+ for (String inputName : signature .inputNames ()) {
224
224
if (!arguments .containsKey (inputName )) {
225
225
throw new IllegalArgumentException (
226
226
"Function "
227
- + signature () .methodName ()
227
+ + signature .methodName ()
228
228
+ " has parameter \" "
229
229
+ inputName
230
230
+ "\" , but no argument was passed for it." );
@@ -241,30 +241,30 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
241
241
}
242
242
243
243
List <Output <?>> outputList =
244
- PartitionedCall .create (
245
- scope ,
246
- inputList ,
247
- Arrays .stream (outputDtypes )
248
- .map (x -> TensorTypeRegistry .find (x ).type ())
249
- .collect (Collectors .toList ()),
250
- this )
251
- .output ();
252
-
253
- Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature ().outputNames ().size ());
254
-
255
- List <String > outputNames = new ArrayList <>(signature ().outputNames ());
256
- for (int i = 0 ; i < outputNames .size (); i ++) {
257
- String outputName = outputNames .get (i );
258
-
259
- if (i > outputList .size ()) {
260
- throw new IllegalStateException (
261
- "Somehow, not all required outputs were returned from the function" );
262
- }
244
+ PartitionedCall .create (scope , inputList , outputTypes , this ).output ();
263
245
246
+ if (signature .outputNames ().size () == 0 ) {
247
+ return Collections .emptyMap ();
248
+ }
249
+ if (signature .outputNames ().size () == 1 ) {
250
+ return Collections .singletonMap (signature .outputNames ().iterator ().next (), outputList .get (0 ));
251
+ }
252
+ if (outputList .size () < signature .outputNames ().size ()) {
253
+ throw new IllegalStateException (
254
+ "Somehow, not all required outputs were returned from the function"
255
+ + "(expected: "
256
+ + signature .outputNames ().size ()
257
+ + ", returned: "
258
+ + outputList .size ()
259
+ + ")" );
260
+ }
261
+ Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature .outputNames ().size ());
262
+ Iterator <String > outputNames = signature .outputNames ().iterator ();
263
+ for (int i = 0 ; outputNames .hasNext (); i ++) {
264
+ String outputName = outputNames .next ();
264
265
Operand <?> output = outputList .get (i );
265
266
namedOutputs .put (outputName , output );
266
267
}
267
-
268
268
return Collections .unmodifiableMap (namedOutputs );
269
269
}
270
270
@@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
291
291
}
292
292
String outputName = signatureDef .getOutputsMap ().keySet ().iterator ().next ();
293
293
294
- Map <String , Operand <?>> inputMap = new LinkedHashMap <>();
295
- inputMap .put (inputName , argument );
296
-
297
- return call (scope , inputMap ).get (outputName );
294
+ return call (scope , Collections .singletonMap (inputName , argument )).get (outputName );
298
295
}
299
296
300
297
@ Override
@@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
395
392
private final NativeFunction nativeFunction ;
396
393
private final PointerScope scope ;
397
394
private final Set <TF_Function > dependencies ;
398
- private final DataType [] inputDtypes ;
399
- private final DataType [] outputDtypes ;
395
+ private final List <Class <? extends TType >> outputTypes ;
400
396
401
397
/** All native functions should have deallocators registered */
402
398
private ConcreteFunction (
@@ -405,7 +401,7 @@ private ConcreteFunction(
405
401
this .nativeFunction = nativeFunction ;
406
402
this .dependencies = Collections .unmodifiableSet (dependencies );
407
403
408
- if (this . signature .getInputs ().size ()
404
+ if (signature .getInputs ().size ()
409
405
!= nativeFunction .getFunctionDef ().getSignature ().getInputArgCount ()) {
410
406
throw new IllegalArgumentException (
411
407
"Signature must have the same number of inputs as the native function. Expected "
@@ -414,7 +410,7 @@ private ConcreteFunction(
414
410
+ this .signature .getInputs ().size ());
415
411
}
416
412
417
- if (this . signature .getOutputs ().size ()
413
+ if (signature .getOutputs ().size ()
418
414
!= nativeFunction .getFunctionDef ().getSignature ().getOutputArgCount ()) {
419
415
throw new IllegalArgumentException (
420
416
"New signature must have the same number of outputs as the native function. Expected "
@@ -423,10 +419,8 @@ private ConcreteFunction(
423
419
+ this .signature .getOutputs ().size ());
424
420
}
425
421
426
- inputDtypes =
427
- this .signature .getInputs ().values ().stream ().map (x -> x .dataType ).toArray (DataType []::new );
428
-
429
- List <DataType > inputs = Arrays .asList (inputDtypes );
422
+ List <DataType > inputs =
423
+ signature .getInputs ().values ().stream ().map (x -> x .dataType ).collect (Collectors .toList ());
430
424
List <DataType > nativeInputs =
431
425
nativeFunction .getFunctionDef ().getSignature ().getInputArgList ().stream ()
432
426
.map (ArgDef ::getType )
@@ -440,10 +434,8 @@ private ConcreteFunction(
440
434
+ inputs );
441
435
}
442
436
443
- outputDtypes =
444
- signature ().getOutputs ().values ().stream ().map (x -> x .dataType ).toArray (DataType []::new );
445
-
446
- List <DataType > outputs = Arrays .asList (outputDtypes );
437
+ List <DataType > outputs =
438
+ signature .getOutputs ().values ().stream ().map (x -> x .dataType ).collect (Collectors .toList ());
447
439
List <DataType > nativeOutputs =
448
440
nativeFunction .getFunctionDef ().getSignature ().getOutputArgList ().stream ()
449
441
.map (ArgDef ::getType )
@@ -457,6 +449,9 @@ private ConcreteFunction(
457
449
+ outputs );
458
450
}
459
451
452
+ outputTypes =
453
+ outputs .stream ().map (x -> TensorTypeRegistry .find (x ).type ()).collect (Collectors .toList ());
454
+
460
455
try (PointerScope scope = new PointerScope ()) {
461
456
this .scope = scope ;
462
457
scope .extend ();
@@ -469,6 +464,8 @@ private ConcreteFunction(
469
464
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
470
465
* how to enable XLA JIT is extremely non-obvious.
471
466
*
467
+ * <p>See https://github.com/tensorflow/java/issues/347
468
+ *
472
469
* <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
473
470
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
474
471
*/
0 commit comments