|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 | """Integration tests for LLVM runtime support.""" |
6 | 6 | from pathlib import Path |
| 7 | +from typing import List |
7 | 8 |
|
8 | 9 | import numpy as np |
9 | 10 | import pytest |
10 | 11 | from flaky import flaky |
11 | 12 |
|
12 | 13 | from compiler_gym.envs.llvm import LlvmEnv, llvm_benchmark |
| 14 | +from compiler_gym.spaces.reward import Reward |
| 15 | +from compiler_gym.util.gym_type_hints import ActionType, ObservationType |
13 | 16 | from tests.test_main import main |
14 | 17 |
|
15 | 18 | pytest_plugins = ["tests.pytest_plugins.llvm"] |
@@ -144,5 +147,63 @@ def test_default_runtime_observation_count_fork(env: LlvmEnv): |
144 | 147 | assert fkd.runtime_warmup_runs_count == wc |
145 | 148 |
|
146 | 149 |
|
class RewardDerivedFromRuntime(Reward):
    """Reward space for tests that is driven by the "Runtime" observation space.

    The reward value produced by :meth:`update` is always zero. The point of
    the class is to capture the most recent Runtime observation passed to the
    reward space in ``last_runtime_observation`` so that a test can inspect it.
    """

    def __init__(self):
        reward_config = dict(
            name="runtimeseries",
            observation_spaces=["Runtime"],
            default_value=0,
            min=None,
            max=None,
            default_negates_returns=True,
            deterministic=False,
            platform_dependent=True,
        )
        super().__init__(**reward_config)
        # Most recent Runtime observation stored by reset() or update();
        # None until one of them has been called.
        self.last_runtime_observation: List[float] = None

    def reset(self, benchmark, observation_view) -> None:
        """Store the "Runtime" entry read from ``observation_view``.

        Args:
            benchmark: Not used by this reward space.
            observation_view: Indexable view; its ``"Runtime"`` entry is
                recorded as the latest observation.
        """
        runtime = observation_view["Runtime"]
        self.last_runtime_observation = runtime

    def update(
        self,
        actions: List[ActionType],
        observations: List[ObservationType],
        observation_view,
    ) -> float:
        """Store the first entry of ``observations`` and return a zero reward.

        Args:
            actions: Not used.
            observations: Observations supplied for this reward space. Only
                the first entry is kept (presumably the Runtime observation,
                since that is the single space requested in ``__init__``).
            observation_view: Not used.

        Returns:
            Always ``0``.
        """
        del actions, observation_view  # both unused
        self.last_runtime_observation = observations[0]
        return 0
| 179 | + |
| 180 | + |
@flaky  # runtime may fail
@pytest.mark.parametrize("runtime_observation_count", [1, 3, 5])
def test_correct_number_of_observations_during_reset(
    env: LlvmEnv, runtime_observation_count: int
):
    """Check how many Runtime observations a reward space sees on reset/step."""
    space_name = "runtimeseries"

    def recorded_observation_count() -> int:
        # Resolve the reward space from the environment on every call rather
        # than caching the object, so each check reads the live space.
        return len(env.reward.spaces[space_name].last_runtime_observation)

    env.reward.add_space(RewardDerivedFromRuntime())
    env.runtime_observation_count = runtime_observation_count
    env.reset(reward_space=space_name)
    assert env.runtime_observation_count == runtime_observation_count

    # Check that the number of observations that you receive during reset()
    # matches the amount that you asked for.
    # FIXME(github.com/facebookresearch/CompilerGym/issues/756): This is broken.
    # Only a single observation is received, irrespective of how many you ask
    # for.
    assert recorded_observation_count() == 1

    # Check that the number of observations that you receive during step()
    # matches the amount that you asked for. Clear the stored value first so
    # that the assertion can only pass on what step() delivers.
    env.reward.spaces[space_name].last_runtime_observation = None
    env.step(0)
    assert recorded_observation_count() == runtime_observation_count
| 206 | + |
| 207 | + |
if __name__ == "__main__":
    # When executed as a script, hand off to the `main` helper imported from
    # tests.test_main.
    main()
0 commit comments