16
16
package org .tensorflow ;
17
17
18
18
import static org .tensorflow .Graph .resolveOutputs ;
19
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_CloseSession ;
20
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_DeleteSession ;
21
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_NewSession ;
19
+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_OperationGetAttrType ;
22
20
import static org .tensorflow .internal .c_api .global .tensorflow .TF_SessionRun ;
23
21
import static org .tensorflow .internal .c_api .global .tensorflow .TF_SetConfig ;
24
22
38
36
import org .tensorflow .internal .c_api .TF_SessionOptions ;
39
37
import org .tensorflow .internal .c_api .TF_Status ;
40
38
import org .tensorflow .internal .c_api .TF_Tensor ;
39
+ import org .tensorflow .internal .types .registry .TensorTypeRegistry ;
41
40
import org .tensorflow .op .Op ;
41
+ import org .tensorflow .op .Ops ;
42
+ import org .tensorflow .op .core .ReadVariableOp ;
42
43
import org .tensorflow .proto .framework .ConfigProto ;
44
+ import org .tensorflow .proto .framework .DataType ;
43
45
import org .tensorflow .proto .framework .RunMetadata ;
44
46
import org .tensorflow .proto .framework .RunOptions ;
45
47
import org .tensorflow .proto .util .SaverDef ;
@@ -192,6 +194,11 @@ public Runner feed(String operation, int index, Tensor t) {
192
194
* @return this session runner
193
195
*/
194
196
public Runner feed (Operand <?> operand , Tensor t ) {
197
+ if (operand .env () != graph ) {
198
+ throw new IllegalStateException ("Can't feed value for operand " + operand + ", it is from " +
199
+ (operand .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
200
+ }
201
+
195
202
inputs .add (operand .asOutput ());
196
203
inputTensors .add (t );
197
204
return this ;
@@ -200,6 +207,8 @@ public Runner feed(Operand<?> operand, Tensor t) {
200
207
/**
201
208
* Make {@link #run()} return the output of {@code operation}.
202
209
*
210
+ * If the output is a resource variable, will fetch the value.
211
+ *
203
212
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
204
213
* fetch(operation, 0)}, or it is a string of the form
205
214
* <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -215,6 +224,8 @@ public Runner fetch(String operation) {
215
224
/**
216
225
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
217
226
*
227
+ * If the output is a resource variable, will fetch the value.
228
+ *
218
229
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
219
230
* one to return.
220
231
*
@@ -225,24 +236,61 @@ public Runner fetch(String operation) {
225
236
*/
226
237
public Runner fetch (String operation , int index ) {
227
238
Operation op = graph .operationOrThrow (operation );
228
- outputs .add (op .output (index ));
229
- return this ;
239
+ return fetch (op .output (index ));
230
240
}
231
241
232
242
/**
233
243
* Makes {@link #run()} return the Tensor referred to by {@code output}.
234
244
*
245
+ * If {@code output} is a resource variable, will fetch the value.
246
+ *
235
247
* @param output the node to fetch the tensor from
236
248
* @return this session runner
237
249
*/
238
250
public Runner fetch (Output <?> output ) {
239
- outputs .add (output );
251
+ if (output .env () != graph ) {
252
+ throw new IllegalStateException ("Can't fetch output " + output + ", it is from " +
253
+ (output .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
254
+ }
255
+
256
+ if (output .dataType () == DataType .DT_RESOURCE ) {
257
+ int [] rawDt = new int [1 ];
258
+
259
+ GraphOperation graphOp = (GraphOperation ) output .op ();
260
+
261
+ try (PointerScope scope = new PointerScope ()) {
262
+ TF_Status status = TF_Status .newStatus ();
263
+ TF_OperationGetAttrType (graphOp .getUnsafeNativeHandle (), "dtype" , rawDt , status );
264
+ status .throwExceptionIfNotOK ();
265
+ }
266
+
267
+ DataType valueDt = DataType .forNumber (rawDt [0 ]);
268
+
269
+ Operand <?> read = null ;
270
+ for (GraphOperation op : graphOp .consumers ()) {
271
+ if (op .dtype (0 ) == valueDt && op .type ().equals (ReadVariableOp .OP_NAME )) {
272
+ read = op .output (0 );
273
+ break ;
274
+ }
275
+ }
276
+
277
+ if (read == null ) {
278
+ read = Ops .create (graph ).withSubScope ("session_reads" ).withName (output .op ().name () + "_read" )
279
+ .readVariableOp (output , TensorTypeRegistry .find (valueDt ).type ());
280
+ }
281
+
282
+ outputs .add (read .asOutput ());
283
+ } else {
284
+ outputs .add (output );
285
+ }
240
286
return this ;
241
287
}
242
288
243
289
/**
244
290
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
245
291
*
292
+ * If {@code operand} is a resource variable, will fetch the value.
293
+ *
246
294
* @param operand the node to fetch the tensor from, as an operand
247
295
* @return this session runner
248
296
*/
@@ -258,9 +306,7 @@ public Runner fetch(Operand<?> operand) {
258
306
* @throws IllegalArgumentException if no operation exists with the provided name
259
307
*/
260
308
public Runner addTarget (String operation ) {
261
- GraphOperation op = graph .operationOrThrow (operation );
262
- targets .add (op );
263
- return this ;
309
+ return addTarget (graph .operationOrThrow (operation ));
264
310
}
265
311
266
312
/**
@@ -269,13 +315,12 @@ public Runner addTarget(String operation) {
269
315
* @param operation the operation to execute
270
316
* @return this session runner
271
317
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
318
+ * @throws IllegalStateException if the operation is not from the session's graph.
272
319
*/
273
320
public Runner addTarget (Operation operation ) {
274
- if (!(operation instanceof GraphOperation )) {
275
- throw new IllegalArgumentException (
276
- "Operation of type "
277
- + operation .getClass ().getName ()
278
- + " is not supported in graph sessions" );
321
+ if (operation .env () != graph ) {
322
+ throw new IllegalStateException ("Can't target operation " + operation + ", it is from " +
323
+ (operation .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
279
324
}
280
325
targets .add ((GraphOperation ) operation );
281
326
return this ;
@@ -594,12 +639,12 @@ private static void delete(TF_Session handle) {
594
639
*
595
640
* @param handle to the C API TF_Session object (Session.nativeHandle)
596
641
* @param runOptions A RunOptions protocol buffer, or null
597
- * @param inputOpHandles (see inputOpIndices)
598
- * @param inputOpIndices (see inputTensorHandles)
599
642
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
600
643
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
601
644
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
602
645
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
646
+ * @param inputOpHandles (see inputOpIndices)
647
+ * @param inputOpIndices (see inputTensorHandles)
603
648
* @param outputOpHandles (see outputOpIndices)
604
649
* @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
605
650
* outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==
0 commit comments