16
16
package org .tensorflow ;
17
17
18
18
import static org .tensorflow .Graph .resolveOutputs ;
19
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_OperationGetAttrType ;
19
+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_CloseSession ;
20
+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_DeleteSession ;
21
+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_NewSession ;
20
22
import static org .tensorflow .internal .c_api .global .tensorflow .TF_SessionRun ;
21
23
import static org .tensorflow .internal .c_api .global .tensorflow .TF_SetConfig ;
22
24
36
38
import org .tensorflow .internal .c_api .TF_SessionOptions ;
37
39
import org .tensorflow .internal .c_api .TF_Status ;
38
40
import org .tensorflow .internal .c_api .TF_Tensor ;
39
- import org .tensorflow .internal .types .registry .TensorTypeRegistry ;
40
41
import org .tensorflow .op .Op ;
41
- import org .tensorflow .op .Ops ;
42
- import org .tensorflow .op .core .ReadVariableOp ;
43
42
import org .tensorflow .proto .framework .ConfigProto ;
44
- import org .tensorflow .proto .framework .DataType ;
45
43
import org .tensorflow .proto .framework .RunMetadata ;
46
44
import org .tensorflow .proto .framework .RunOptions ;
47
45
import org .tensorflow .proto .util .SaverDef ;
@@ -194,11 +192,6 @@ public Runner feed(String operation, int index, Tensor t) {
194
192
* @return this session runner
195
193
*/
196
194
public Runner feed (Operand <?> operand , Tensor t ) {
197
- if (operand .env () != graph ) {
198
- throw new IllegalStateException ("Can't feed value for operand " + operand + ", it is from " +
199
- (operand .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
200
- }
201
-
202
195
inputs .add (operand .asOutput ());
203
196
inputTensors .add (t );
204
197
return this ;
@@ -207,8 +200,6 @@ public Runner feed(Operand<?> operand, Tensor t) {
207
200
/**
208
201
* Make {@link #run()} return the output of {@code operation}.
209
202
*
210
- * If the output is a resource variable, will fetch the value.
211
- *
212
203
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
213
204
* fetch(operation, 0)}, or it is a string of the form
214
205
* <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -224,8 +215,6 @@ public Runner fetch(String operation) {
224
215
/**
225
216
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
226
217
*
227
- * If the output is a resource variable, will fetch the value.
228
- *
229
218
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
230
219
* one to return.
231
220
*
@@ -236,61 +225,24 @@ public Runner fetch(String operation) {
236
225
*/
237
226
public Runner fetch (String operation , int index ) {
238
227
Operation op = graph .operationOrThrow (operation );
239
- return fetch (op .output (index ));
228
+ outputs .add (op .output (index ));
229
+ return this ;
240
230
}
241
231
242
232
/**
243
233
* Makes {@link #run()} return the Tensor referred to by {@code output}.
244
234
*
245
- * If {@code output} is a resource variable, will fetch the value.
246
- *
247
235
* @param output the node to fetch the tensor from
248
236
* @return this session runner
249
237
*/
250
238
public Runner fetch (Output <?> output ) {
251
- if (output .env () != graph ) {
252
- throw new IllegalStateException ("Can't fetch output " + output + ", it is from " +
253
- (output .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
254
- }
255
-
256
- if (output .dataType () == DataType .DT_RESOURCE ) {
257
- int [] rawDt = new int [1 ];
258
-
259
- GraphOperation graphOp = (GraphOperation ) output .op ();
260
-
261
- try (PointerScope scope = new PointerScope ()) {
262
- TF_Status status = TF_Status .newStatus ();
263
- TF_OperationGetAttrType (graphOp .getUnsafeNativeHandle (), "dtype" , rawDt , status );
264
- status .throwExceptionIfNotOK ();
265
- }
266
-
267
- DataType valueDt = DataType .forNumber (rawDt [0 ]);
268
-
269
- Operand <?> read = null ;
270
- for (GraphOperation op : graphOp .consumers ()) {
271
- if (op .dtype (0 ) == valueDt && op .type ().equals (ReadVariableOp .OP_NAME )) {
272
- read = op .output (0 );
273
- break ;
274
- }
275
- }
276
-
277
- if (read == null ) {
278
- read = Ops .create (graph ).withSubScope ("session_reads" ).withName (output .op ().name () + "_read" )
279
- .readVariableOp (output , TensorTypeRegistry .find (valueDt ).type ());
280
- }
281
-
282
- outputs .add (read .asOutput ());
283
- } else {
284
- outputs .add (output );
285
- }
239
+ outputs .add (output );
286
240
return this ;
287
241
}
288
242
289
243
/**
290
244
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
291
245
*
292
- * If {@code operand} is a resource variable, will fetch the value.
293
- *
294
246
* @param operand the node to fetch the tensor from, as an operand
295
247
* @return this session runner
296
248
*/
@@ -306,7 +258,9 @@ public Runner fetch(Operand<?> operand) {
306
258
* @throws IllegalArgumentException if no operation exists with the provided name
307
259
*/
308
260
public Runner addTarget (String operation ) {
309
- return addTarget (graph .operationOrThrow (operation ));
261
+ GraphOperation op = graph .operationOrThrow (operation );
262
+ targets .add (op );
263
+ return this ;
310
264
}
311
265
312
266
/**
@@ -315,12 +269,13 @@ public Runner addTarget(String operation) {
315
269
* @param operation the operation to execute
316
270
* @return this session runner
317
271
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
318
- * @throws IllegalStateException if the operation is not from the session's graph.
319
272
*/
320
273
public Runner addTarget (Operation operation ) {
321
- if (operation .env () != graph ) {
322
- throw new IllegalStateException ("Can't target operation " + operation + ", it is from " +
323
- (operation .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
274
+ if (!(operation instanceof GraphOperation )) {
275
+ throw new IllegalArgumentException (
276
+ "Operation of type "
277
+ + operation .getClass ().getName ()
278
+ + " is not supported in graph sessions" );
324
279
}
325
280
targets .add ((GraphOperation ) operation );
326
281
return this ;
@@ -639,12 +594,12 @@ private static void delete(TF_Session handle) {
639
594
*
640
595
* @param handle to the C API TF_Session object (Session.nativeHandle)
641
596
* @param runOptions A RunOptions protocol buffer, or null
597
+ * @param inputOpHandles (see inputOpIndices)
598
+ * @param inputOpIndices (see inputTensorHandles)
642
599
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
643
600
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
644
601
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
645
602
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
646
- * @param inputOpHandles (see inputOpIndices)
647
- * @param inputOpIndices (see inputTensorHandles)
648
603
* @param outputOpHandles (see outputOpIndices)
649
604
* @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
650
605
* outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==
0 commit comments