Skip to content

Commit 1d91115

Browse files
committed
Added fixes based on maintainer's feedbacks
1 parent 42843db commit 1d91115

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

compiler_gym/wrappers/time_limit.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,8 @@ def __init__(self, env: CompilerEnv, max_episode_steps: Optional[int] = None):
3232
self._max_episode_steps = max_episode_steps
3333
self._elapsed_steps = None
3434

35-
# def step(self, action: ActionType, **kwargs):
36-
# assert (
37-
# self._elapsed_steps is not None
38-
# ), "Cannot call env.step() before calling reset()"
39-
# observation, reward, done, info = self.env.step(action, **kwargs)
40-
# self._elapsed_steps += 1
41-
# if self._elapsed_steps >= self._max_episode_steps:
42-
# info["TimeLimit.truncated"] = not done
43-
# done = True
44-
# return observation, reward, done, info
45-
4635
def multistep(self, actions: Iterable[ActionType], **kwargs):
36+
actions = list(actions)
4737
assert (
4838
self._elapsed_steps is not None
4939
), "Cannot call env.step() before calling reset()"

tests/wrappers/time_limit_wrappers_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,26 @@ def test_time_limit_fork(env: LlvmEnv):
8181
fkd.close()
8282

8383

84-
# @pytest.mark.xfail(strict=True, reason="https://github.com/facebookresearch/CompilerGym/issues/648")
8584
def test_time_limit(env: LlvmEnv):
8685
"""Check CycleOverBenchmarks does not break TimeLimit"""
87-
env = TimeLimit(env, max_episode_steps=1)
86+
env = TimeLimit(env, max_episode_steps=3)
8887
env = CycleOverBenchmarks(
8988
env,
9089
benchmarks=[
9190
"benchmark://cbench-v1/crc32",
9291
],
9392
)
9493
env.reset()
95-
_, _, done, _ = env.step(0)
9694

95+
_, _, done, info = env.step(0)
96+
assert not done, info
97+
98+
_, _, done, info = env.step(0)
99+
assert not done, info
100+
101+
_, _, done, info = env.step(0)
97102
assert done
103+
assert info["TimeLimit.truncated"], info
98104

99105

100106
if __name__ == "__main__":

0 commit comments

Comments
 (0)