Skip to content

Commit 620214f

Browse files
[RLlib] Fix _test_dependency_torch (#60742) (#60888)
## Description: Cherry-pick of "Fix _test_dependency_torch" (#60742) into releases/2.54.0. Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com> Signed-off-by: Kamil Kaczmarek <kamil@anyscale.com> Co-authored-by: Artur Niederfahrenhorst <artur@anyscale.com>
1 parent 5d2115c commit 620214f

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

rllib/evaluation/postprocessing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Dict, Optional
22

33
import numpy as np
4-
import scipy.signal
54

65
from ray.rllib.policy.policy import Policy
76
from ray.rllib.policy.sample_batch import SampleBatch
@@ -325,4 +324,7 @@ def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
325324
2.0 + 0.9*3.0,
326325
3.0])
327326
"""
327+
# Import scipy here to avoid import error when framework is tensorflow.
328+
import scipy
329+
328330
return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]

rllib/offline/offline_policy_evaluation_runner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515

1616
import ray
1717
from ray.data.iterator import DataIterator
18-
from ray.data.util.torch_utils import (
19-
convert_ndarray_batch_to_torch_tensor_batch,
20-
)
2118
from ray.rllib.connectors.env_to_module import EnvToModulePipeline
2219
from ray.rllib.core import (
2320
ALL_MODULES,
@@ -116,6 +113,12 @@ def _collate_fn(
116113
_batch: Dict[EpisodeID, Dict[str, numpy.ndarray]],
117114
) -> Dict[EpisodeID, Dict[str, TensorType]]:
118115
"""Converts a batch of episodes to torch tensors."""
116+
# Avoid torch import error when framework is tensorflow.
117+
# Note (artur): This can be removed when we remove tf support.
118+
from ray.data.util.torch_utils import (
119+
convert_ndarray_batch_to_torch_tensor_batch,
120+
)
121+
119122
return [
120123
convert_ndarray_batch_to_torch_tensor_batch(
121124
episode, device=self._device, dtypes=torch.float32

rllib/utils/tf_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -892,18 +892,31 @@ def __init__(self, output, sess=None, input_variables=None):
892892
if input_variables is not None:
893893
variable_list += input_variables
894894

895+
def _get_var_name(v):
896+
"""Get variable name, supporting both TF1 ResourceVariable and
897+
Keras 3 Variable objects."""
898+
if hasattr(v, "op"):
899+
return v.op.node_def.name
900+
return v.name
901+
895902
if not tf1.executing_eagerly():
896903
for v in variable_list:
897-
self.variables[v.op.node_def.name] = v
904+
self.variables[_get_var_name(v)] = v
898905

899906
self.placeholders = {}
900907
self.assignment_nodes = {}
901908

902909
# Create new placeholders to put in custom weights.
903910
for k, var in self.variables.items():
911+
dtype = var.value().dtype if hasattr(var, "op") else var.dtype
912+
shape = (
913+
var.get_shape().as_list()
914+
if hasattr(var, "get_shape")
915+
else list(var.shape)
916+
)
904917
self.placeholders[k] = tf1.placeholder(
905-
var.value().dtype,
906-
var.get_shape().as_list(),
918+
dtype,
919+
shape,
907920
name="Placeholder_" + k,
908921
)
909922
self.assignment_nodes[k] = var.assign(self.placeholders[k])

0 commit comments

Comments (0)