Skip to content

Commit 5216c0e

Browse files
author
Chris Cummins
committed
[tests] Add a unit test to repro #756.
1 parent 1c40e5b commit 5216c0e

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

tests/llvm/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ py_test(
259259
deps = [
260260
"//compiler_gym/envs/llvm",
261261
"//compiler_gym/service:connection",
262+
"//compiler_gym/spaces",
263+
"//compiler_gym/util",
262264
"//tests:test_main",
263265
"//tests/pytest_plugins:llvm",
264266
],

tests/llvm/runtime_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
# LICENSE file in the root directory of this source tree.
55
"""Integrations tests for LLVM runtime support."""
66
from pathlib import Path
7+
from typing import List
78

89
import numpy as np
910
import pytest
1011
from flaky import flaky
1112

1213
from compiler_gym.envs.llvm import LlvmEnv, llvm_benchmark
14+
from compiler_gym.spaces.reward import Reward
15+
from compiler_gym.util.gym_type_hints import ActionType, ObservationType
1316
from tests.test_main import main
1417

1518
pytest_plugins = ["tests.pytest_plugins.llvm"]
@@ -144,5 +147,63 @@ def test_default_runtime_observation_count_fork(env: LlvmEnv):
144147
assert fkd.runtime_warmup_runs_count == wc
145148

146149

150+
class RewardDerivedFromRuntime(Reward):
151+
"""A custom reward space that is derived from the Runtime observation space."""
152+
153+
def __init__(self):
154+
super().__init__(
155+
name="runtimeseries",
156+
observation_spaces=["Runtime"],
157+
default_value=0,
158+
min=None,
159+
max=None,
160+
default_negates_returns=True,
161+
deterministic=False,
162+
platform_dependent=True,
163+
)
164+
self.last_runtime_observation: List[float] = None
165+
166+
def reset(self, benchmark, observation_view) -> None:
167+
self.last_runtime_observation = observation_view["Runtime"]
168+
169+
def update(
170+
self,
171+
actions: List[ActionType],
172+
observations: List[ObservationType],
173+
observation_view,
174+
) -> float:
175+
del actions # unused
176+
del observation_view # unused
177+
self.last_runtime_observation = observations[0]
178+
return 0
179+
180+
181+
@flaky # runtime may fail
182+
@pytest.mark.parametrize("runtime_observation_count", [1, 3, 5])
183+
def test_correct_number_of_observations_during_reset(
184+
env: LlvmEnv, runtime_observation_count: int
185+
):
186+
env.reward.add_space(RewardDerivedFromRuntime())
187+
env.runtime_observation_count = runtime_observation_count
188+
env.reset(reward_space="runtimeseries")
189+
assert env.runtime_observation_count == runtime_observation_count
190+
191+
# Check that the number of observations that you are receive during reset()
192+
# matches the amount that you asked for.
193+
# FIXME(github.com/facebookresearch/CompilerGym/issues/756): This is broken.
194+
# Only a single observation is received, irrespective of how many you ask
195+
# for.
196+
assert len(env.reward.spaces["runtimeseries"].last_runtime_observation) == 1
197+
198+
# Check that the number of observations that you are receive during step()
199+
# matches the amount that you asked for.
200+
env.reward.spaces["runtimeseries"].last_runtime_observation = None
201+
env.step(0)
202+
assert (
203+
len(env.reward.spaces["runtimeseries"].last_runtime_observation)
204+
== runtime_observation_count
205+
)
206+
207+
147208
if __name__ == "__main__":
148209
main()

0 commit comments

Comments
 (0)