Skip to content

Commit 9c1b26e

Browse files
authored
Revert "Add fetchVariable method to Session to get value of resource variable (#261)"
This reverts commit e229028.
1 parent e229028 commit 9c1b26e

File tree

7 files changed

+60
-170
lines changed

7 files changed

+60
-170
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 16 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
package org.tensorflow;
1717

1818
import static org.tensorflow.Graph.resolveOutputs;
19-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType;
19+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
20+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
21+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
2022
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun;
2123
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
2224

@@ -36,12 +38,8 @@
3638
import org.tensorflow.internal.c_api.TF_SessionOptions;
3739
import org.tensorflow.internal.c_api.TF_Status;
3840
import org.tensorflow.internal.c_api.TF_Tensor;
39-
import org.tensorflow.internal.types.registry.TensorTypeRegistry;
4041
import org.tensorflow.op.Op;
41-
import org.tensorflow.op.Ops;
42-
import org.tensorflow.op.core.ReadVariableOp;
4342
import org.tensorflow.proto.framework.ConfigProto;
44-
import org.tensorflow.proto.framework.DataType;
4543
import org.tensorflow.proto.framework.RunMetadata;
4644
import org.tensorflow.proto.framework.RunOptions;
4745
import org.tensorflow.proto.util.SaverDef;
@@ -194,11 +192,6 @@ public Runner feed(String operation, int index, Tensor t) {
194192
* @return this session runner
195193
*/
196194
public Runner feed(Operand<?> operand, Tensor t) {
197-
if (operand.env() != graph) {
198-
throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " +
199-
(operand.env().isEager() ? "an eager session" : "a different graph") + ".");
200-
}
201-
202195
inputs.add(operand.asOutput());
203196
inputTensors.add(t);
204197
return this;
@@ -207,8 +200,6 @@ public Runner feed(Operand<?> operand, Tensor t) {
207200
/**
208201
* Make {@link #run()} return the output of {@code operation}.
209202
*
210-
* If the output is a resource variable, will fetch the value.
211-
*
212203
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
213204
* fetch(operation, 0)}, or it is a string of the form
214205
* <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -224,8 +215,6 @@ public Runner fetch(String operation) {
224215
/**
225216
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
226217
*
227-
* If the output is a resource variable, will fetch the value.
228-
*
229218
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
230219
* one to return.
231220
*
@@ -236,61 +225,24 @@ public Runner fetch(String operation) {
236225
*/
237226
public Runner fetch(String operation, int index) {
238227
Operation op = graph.operationOrThrow(operation);
239-
return fetch(op.output(index));
228+
outputs.add(op.output(index));
229+
return this;
240230
}
241231

242232
/**
243233
* Makes {@link #run()} return the Tensor referred to by {@code output}.
244234
*
245-
* If {@code output} is a resource variable, will fetch the value.
246-
*
247235
* @param output the node to fetch the tensor from
248236
* @return this session runner
249237
*/
250238
public Runner fetch(Output<?> output) {
251-
if (output.env() != graph) {
252-
throw new IllegalStateException("Can't fetch output " + output + ", it is from " +
253-
(output.env().isEager() ? "an eager session" : "a different graph") + ".");
254-
}
255-
256-
if (output.dataType() == DataType.DT_RESOURCE) {
257-
int[] rawDt = new int[1];
258-
259-
GraphOperation graphOp = (GraphOperation) output.op();
260-
261-
try (PointerScope scope = new PointerScope()) {
262-
TF_Status status = TF_Status.newStatus();
263-
TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status);
264-
status.throwExceptionIfNotOK();
265-
}
266-
267-
DataType valueDt = DataType.forNumber(rawDt[0]);
268-
269-
Operand<?> read = null;
270-
for (GraphOperation op : graphOp.consumers()) {
271-
if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) {
272-
read = op.output(0);
273-
break;
274-
}
275-
}
276-
277-
if (read == null) {
278-
read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read")
279-
.readVariableOp(output, TensorTypeRegistry.find(valueDt).type());
280-
}
281-
282-
outputs.add(read.asOutput());
283-
} else {
284-
outputs.add(output);
285-
}
239+
outputs.add(output);
286240
return this;
287241
}
288242

289243
/**
290244
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
291245
*
292-
* If {@code operand} is a resource variable, will fetch the value.
293-
*
294246
* @param operand the node to fetch the tensor from, as an operand
295247
* @return this session runner
296248
*/
@@ -306,7 +258,9 @@ public Runner fetch(Operand<?> operand) {
306258
* @throws IllegalArgumentException if no operation exists with the provided name
307259
*/
308260
public Runner addTarget(String operation) {
309-
return addTarget(graph.operationOrThrow(operation));
261+
GraphOperation op = graph.operationOrThrow(operation);
262+
targets.add(op);
263+
return this;
310264
}
311265

312266
/**
@@ -315,12 +269,13 @@ public Runner addTarget(String operation) {
315269
* @param operation the operation to execute
316270
* @return this session runner
317271
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
318-
* @throws IllegalStateException if the operation is not from the session's graph.
319272
*/
320273
public Runner addTarget(Operation operation) {
321-
if (operation.env() != graph) {
322-
throw new IllegalStateException("Can't target operation " + operation + ", it is from " +
323-
(operation.env().isEager() ? "an eager session" : "a different graph") + ".");
274+
if (!(operation instanceof GraphOperation)) {
275+
throw new IllegalArgumentException(
276+
"Operation of type "
277+
+ operation.getClass().getName()
278+
+ " is not supported in graph sessions");
324279
}
325280
targets.add((GraphOperation) operation);
326281
return this;
@@ -639,12 +594,12 @@ private static void delete(TF_Session handle) {
639594
*
640595
* @param handle to the C API TF_Session object (Session.nativeHandle)
641596
* @param runOptions A RunOptions protocol buffer, or null
597+
* @param inputOpHandles (see inputOpIndices)
598+
* @param inputOpIndices (see inputTensorHandles)
642599
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
643600
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
644601
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
645602
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
646-
* @param inputOpHandles (see inputOpIndices)
647-
* @param inputOpIndices (see inputTensorHandles)
648603
* @param outputOpHandles (see outputOpIndices)
649604
* @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
650605
* outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import java.nio.file.Path;
2828
import java.nio.file.Paths;
2929
import java.util.Collections;
30-
import java.util.HashMap;
3130
import java.util.Map;
31+
import java.util.HashMap;
3232
import org.junit.jupiter.api.Test;
3333
import org.tensorflow.exceptions.TensorFlowException;
3434
import org.tensorflow.ndarray.FloatNdArray;
@@ -292,29 +292,21 @@ public void pythonTfFunction() {
292292
ConcreteFunction add = bundle.function("add");
293293
Map<String, Tensor> args = new HashMap();
294294
try (TFloat32 a = TFloat32.scalarOf(10.0f);
295-
TFloat32 b = TFloat32.scalarOf(15.5f)) {
295+
TFloat32 b = TFloat32.scalarOf(15.5f)) {
296296
args.put("a", a);
297297
args.put("b", b);
298298
Map<String, Tensor> result = add.call(args);
299299
assertEquals(result.size(), 1);
300-
try (TFloat32 c = (TFloat32) result.values().iterator().next()) {
300+
try (TFloat32 c = (TFloat32)result.values().iterator().next()) {
301301
assertEquals(25.5f, c.getFloat());
302302
}
303303
}
304-
305-
// variable unwrapping happens in Session, which is used by ConcreteFunction.call
306-
ConcreteFunction getVariable = bundle.function("get_variable");
307-
try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>())
308-
.get(getVariable.signature().outputNames().iterator().next())) {
309-
assertEquals(2f, v.getFloat());
310-
}
311-
312304
}
313305
}
314306

315307
private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
316308
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape));
317-
Variable<TFloat32> y = tf.withName("variable")
309+
Variable<TFloat32> y = tf
318310
.variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class));
319311
ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
320312
Init init = tf.init();

0 commit comments

Comments
 (0)