Skip to content

Commit e229028

Browse files
authored
Add fetchVariable method to Session to get value of resource variable (#261)
* Add fetchVariable method to Session to get value of resource variable Signed-off-by: Ryan Nett <[email protected]> * Format Signed-off-by: Ryan Nett <[email protected]> * More Formatting Signed-off-by: Ryan Nett <[email protected]> * Rework, automatically wrap variables in read when fetched Signed-off-by: Ryan Nett <[email protected]> * Forgot to format Signed-off-by: Ryan Nett <[email protected]> * Remove obsolete method Signed-off-by: Ryan Nett <[email protected]> * Small fixes Signed-off-by: Ryan Nett <[email protected]> * Python model loading + variable fetching test Signed-off-by: Ryan Nett <[email protected]>
1 parent cc5c86c commit e229028

File tree

7 files changed

+170
-60
lines changed

7 files changed

+170
-60
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
package org.tensorflow;
1717

1818
import static org.tensorflow.Graph.resolveOutputs;
19-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
20-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
21-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
19+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType;
2220
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun;
2321
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
2422

@@ -38,8 +36,12 @@
3836
import org.tensorflow.internal.c_api.TF_SessionOptions;
3937
import org.tensorflow.internal.c_api.TF_Status;
4038
import org.tensorflow.internal.c_api.TF_Tensor;
39+
import org.tensorflow.internal.types.registry.TensorTypeRegistry;
4140
import org.tensorflow.op.Op;
41+
import org.tensorflow.op.Ops;
42+
import org.tensorflow.op.core.ReadVariableOp;
4243
import org.tensorflow.proto.framework.ConfigProto;
44+
import org.tensorflow.proto.framework.DataType;
4345
import org.tensorflow.proto.framework.RunMetadata;
4446
import org.tensorflow.proto.framework.RunOptions;
4547
import org.tensorflow.proto.util.SaverDef;
@@ -192,6 +194,11 @@ public Runner feed(String operation, int index, Tensor t) {
192194
* @return this session runner
193195
*/
194196
public Runner feed(Operand<?> operand, Tensor t) {
197+
if (operand.env() != graph) {
198+
throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " +
199+
(operand.env().isEager() ? "an eager session" : "a different graph") + ".");
200+
}
201+
195202
inputs.add(operand.asOutput());
196203
inputTensors.add(t);
197204
return this;
@@ -200,6 +207,8 @@ public Runner feed(Operand<?> operand, Tensor t) {
200207
/**
201208
* Make {@link #run()} return the output of {@code operation}.
202209
*
210+
* If the output is a resource variable, will fetch the value.
211+
*
203212
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
204213
* fetch(operation, 0)}, or it is a string of the form
205214
* <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -215,6 +224,8 @@ public Runner fetch(String operation) {
215224
/**
216225
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
217226
*
227+
* If the output is a resource variable, will fetch the value.
228+
*
218229
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
219230
* one to return.
220231
*
@@ -225,24 +236,61 @@ public Runner fetch(String operation) {
225236
*/
226237
public Runner fetch(String operation, int index) {
227238
Operation op = graph.operationOrThrow(operation);
228-
outputs.add(op.output(index));
229-
return this;
239+
return fetch(op.output(index));
230240
}
231241

232242
/**
233243
* Makes {@link #run()} return the Tensor referred to by {@code output}.
234244
*
245+
* If {@code output} is a resource variable, will fetch the value.
246+
*
235247
* @param output the node to fetch the tensor from
236248
* @return this session runner
237249
*/
238250
public Runner fetch(Output<?> output) {
239-
outputs.add(output);
251+
if (output.env() != graph) {
252+
throw new IllegalStateException("Can't fetch output " + output + ", it is from " +
253+
(output.env().isEager() ? "an eager session" : "a different graph") + ".");
254+
}
255+
256+
if (output.dataType() == DataType.DT_RESOURCE) {
257+
int[] rawDt = new int[1];
258+
259+
GraphOperation graphOp = (GraphOperation) output.op();
260+
261+
try (PointerScope scope = new PointerScope()) {
262+
TF_Status status = TF_Status.newStatus();
263+
TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status);
264+
status.throwExceptionIfNotOK();
265+
}
266+
267+
DataType valueDt = DataType.forNumber(rawDt[0]);
268+
269+
Operand<?> read = null;
270+
for (GraphOperation op : graphOp.consumers()) {
271+
if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) {
272+
read = op.output(0);
273+
break;
274+
}
275+
}
276+
277+
if (read == null) {
278+
read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read")
279+
.readVariableOp(output, TensorTypeRegistry.find(valueDt).type());
280+
}
281+
282+
outputs.add(read.asOutput());
283+
} else {
284+
outputs.add(output);
285+
}
240286
return this;
241287
}
242288

243289
/**
244290
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
245291
*
292+
* If {@code operand} is a resource variable, will fetch the value.
293+
*
246294
* @param operand the node to fetch the tensor from, as an operand
247295
* @return this session runner
248296
*/
@@ -258,9 +306,7 @@ public Runner fetch(Operand<?> operand) {
258306
* @throws IllegalArgumentException if no operation exists with the provided name
259307
*/
260308
public Runner addTarget(String operation) {
261-
GraphOperation op = graph.operationOrThrow(operation);
262-
targets.add(op);
263-
return this;
309+
return addTarget(graph.operationOrThrow(operation));
264310
}
265311

266312
/**
@@ -269,13 +315,12 @@ public Runner addTarget(String operation) {
269315
* @param operation the operation to execute
270316
* @return this session runner
271317
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
318+
* @throws IllegalStateException if the operation is not from the session's graph.
272319
*/
273320
public Runner addTarget(Operation operation) {
274-
if (!(operation instanceof GraphOperation)) {
275-
throw new IllegalArgumentException(
276-
"Operation of type "
277-
+ operation.getClass().getName()
278-
+ " is not supported in graph sessions");
321+
if (operation.env() != graph) {
322+
throw new IllegalStateException("Can't target operation " + operation + ", it is from " +
323+
(operation.env().isEager() ? "an eager session" : "a different graph") + ".");
279324
}
280325
targets.add((GraphOperation) operation);
281326
return this;
@@ -594,12 +639,12 @@ private static void delete(TF_Session handle) {
594639
*
595640
* @param handle to the C API TF_Session object (Session.nativeHandle)
596641
* @param runOptions A RunOptions protocol buffer, or null
597-
* @param inputOpHandles (see inputOpIndices)
598-
* @param inputOpIndices (see inputTensorHandles)
599642
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
600643
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
601644
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
602645
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
646+
* @param inputOpHandles (see inputOpIndices)
647+
* @param inputOpIndices (see inputTensorHandles)
603648
* @param outputOpHandles (see outputOpIndices)
604649
* @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
605650
* outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import java.nio.file.Path;
2828
import java.nio.file.Paths;
2929
import java.util.Collections;
30-
import java.util.Map;
3130
import java.util.HashMap;
31+
import java.util.Map;
3232
import org.junit.jupiter.api.Test;
3333
import org.tensorflow.exceptions.TensorFlowException;
3434
import org.tensorflow.ndarray.FloatNdArray;
@@ -292,21 +292,29 @@ public void pythonTfFunction() {
292292
ConcreteFunction add = bundle.function("add");
293293
Map<String, Tensor> args = new HashMap();
294294
try (TFloat32 a = TFloat32.scalarOf(10.0f);
295-
TFloat32 b = TFloat32.scalarOf(15.5f)) {
295+
TFloat32 b = TFloat32.scalarOf(15.5f)) {
296296
args.put("a", a);
297297
args.put("b", b);
298298
Map<String, Tensor> result = add.call(args);
299299
assertEquals(result.size(), 1);
300-
try (TFloat32 c = (TFloat32)result.values().iterator().next()) {
300+
try (TFloat32 c = (TFloat32) result.values().iterator().next()) {
301301
assertEquals(25.5f, c.getFloat());
302302
}
303303
}
304+
305+
// variable unwrapping happens in Session, which is used by ConcreteFunction.call
306+
ConcreteFunction getVariable = bundle.function("get_variable");
307+
try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>())
308+
.get(getVariable.signature().outputNames().iterator().next())) {
309+
assertEquals(2f, v.getFloat());
310+
}
311+
304312
}
305313
}
306314

307315
private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
308316
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape));
309-
Variable<TFloat32> y = tf
317+
Variable<TFloat32> y = tf.withName("variable")
310318
.variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class));
311319
ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
312320
Init init = tf.init();

0 commit comments

Comments
 (0)