Commit f97f697

Chris Cummins authored and committed
[examples] Add implementations of LLVM autotuners.
1 parent 94c9e79 commit f97f697

File tree

12 files changed: +582 −1 lines changed


compiler_gym/envs/llvm/compute_observation.py

Lines changed: 2 additions & 1 deletion

@@ -11,6 +11,7 @@
 import google.protobuf.text_format

 from compiler_gym.service.proto import Observation
+from compiler_gym.util.gym_type_hints import ObservationType
 from compiler_gym.util.runfiles_path import runfiles_path
 from compiler_gym.util.shell_format import plural
 from compiler_gym.views.observation_space_spec import ObservationSpaceSpec
@@ -35,7 +36,7 @@ def pascal_case_to_enum(pascal_case: str) -> str:

 def compute_observation(
     observation_space: ObservationSpaceSpec, bitcode: Path, timeout: float = 300
-):
+) -> ObservationType:
     """Compute an LLVM observation.

     This is a utility function that uses a standalone C++ binary to compute an
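
For context, a minimal sketch of how the newly annotated API is called. This assumes the standard LlvmEnv.write_bitcode() helper and env.observation.spaces lookup; the bitcode path is chosen purely for illustration:

# Hedged usage sketch for compute_observation() with its new
# ObservationType return annotation; the path is illustrative only.
from pathlib import Path

import compiler_gym
from compiler_gym.envs.llvm import compute_observation

with compiler_gym.make("llvm-v0") as env:
    env.reset(benchmark="benchmark://cbench-v1/crc32")
    bitcode = Path("/tmp/crc32.bc")
    env.write_bitcode(bitcode)  # serialize the current module to disk
    space = env.observation.spaces["IrInstructionCount"]
    count = compute_observation(space, bitcode)  # returns an ObservationType
    print(count)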
Lines changed: 96 additions & 0 deletions

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module defines a class for describing LLVM autotuners."""
import tempfile
from pathlib import Path
from typing import Any, Dict

from llvm_autotuning.autotuners.greedy import greedy  # noqa autotuner
from llvm_autotuning.autotuners.nevergrad_ import nevergrad  # noqa autotuner
from llvm_autotuning.autotuners.opentuner_ import opentuner_ga  # noqa autotuner
from llvm_autotuning.autotuners.random_ import random  # noqa autotuner
from llvm_autotuning.optimization_target import OptimizationTarget
from pydantic import BaseModel, validator

from compiler_gym.compiler_env_state import CompilerEnvState
from compiler_gym.envs import CompilerEnv
from compiler_gym.util.capture_output import capture_output
from compiler_gym.util.runfiles_path import transient_cache_path
from compiler_gym.util.temporary_working_directory import temporary_working_directory
from compiler_gym.util.timer import Timer


class Autotuner(BaseModel):

    algorithm: str
    """The name of the autotuner algorithm."""

    optimization_target: OptimizationTarget
    """The target that the autotuner is optimizing for."""

    search_time_seconds: int
    """The search budget of the autotuner."""

    algorithm_config: Dict[str, Any] = {}
    """An optional dictionary of keyword arguments for the autotuner function."""

    @property
    def autotune(self):
        """Return the autotuner function for this algorithm.

        An autotuner function takes a single CompilerEnv argument and optional
        keyword configuration arguments (determined by algorithm_config) and
        tunes the environment, returning nothing.
        """
        try:
            return globals()[self.algorithm]
        except KeyError as e:
            raise ValueError(
                f"Unknown autotuner: {self.algorithm}.\n"
                f"Make sure the {self.algorithm}() function definition is available "
                f"in the global namespace of {__file__}."
            ) from e

    @property
    def autotune_kwargs(self) -> Dict[str, Any]:
        """Get the keyword arguments dictionary for the autotuner."""
        kwargs = {
            "optimization_target": self.optimization_target,
            "search_time_seconds": self.search_time_seconds,
        }
        kwargs.update(self.algorithm_config)
        return kwargs

    def __call__(self, env: CompilerEnv, seed: int = 0xCC) -> CompilerEnvState:
        """Autotune the given environment.

        :param env: The environment to autotune.

        :param seed: The random seed for the autotuner.

        :returns: A CompilerEnvState tuple describing the autotuning result.
        """
        # Run the autotuner in a temporary working directory and capture the
        # stdout/stderr.
        with tempfile.TemporaryDirectory(
            dir=transient_cache_path("."), prefix="autotune-"
        ) as tmpdir:
            with temporary_working_directory(Path(tmpdir)):
                with capture_output():
                    with Timer() as timer:
                        self.autotune(env, seed=seed, **self.autotune_kwargs)

        return CompilerEnvState(
            benchmark=env.benchmark.uri,
            commandline=env.commandline(),
            walltime=timer.time,
            reward=self.optimization_target.final_reward(env),
        )

    # === Start of implementation details. ===

    @validator("algorithm_config", pre=True)
    def validate_algorithm_config(cls, value) -> Dict[str, Any]:
        return value or {}
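
The integration tests later in this commit show the intended entry point; for quick reference, a minimal driving sketch (the import path mirrors those tests, and the algorithm name and time budget are illustrative):

# Illustrative sketch of driving the Autotuner class above.
import compiler_gym
from llvm_autotuning.autotuners import Autotuner

with compiler_gym.make("llvm-v0", reward_space="IrInstructionCount") as env:
    env.reset(benchmark="benchmark://cbench-v1/crc32")
    tuner = Autotuner(
        algorithm="greedy",
        optimization_target="codesize",
        search_time_seconds=60,
    )
    state = tuner(env)  # returns a CompilerEnvState
    print(state.reward, state.commandline)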
Lines changed: 26 additions & 0 deletions

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from time import time


def greedy(env, search_time_seconds: int, **kwargs) -> None:
    """A greedy search policy.

    At each step, the policy evaluates all possible actions and selects the
    action with the highest reward. The search stops when no action produces a
    positive reward.

    :param env: The environment to optimize.
    """

    def eval_action(env, action: int):
        with env.fork() as fkd:
            return (fkd.step(action)[1], action)

    end_time = time() + search_time_seconds
    while time() < end_time:
        best = max(eval_action(env, action) for action in range(env.action_space.n))
        if best[0] <= 0 or env.step(best[1])[2]:
            return
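
A note on the max() call: eval_action() returns (reward, action) tuples, and Python compares tuples lexicographically, so max() selects the action with the highest reward. A standalone illustration with made-up rewards:

# Made-up rewards, purely to illustrate lexicographic tuple comparison.
candidates = [(0.0, 0), (2.5, 1), (1.0, 2)]  # (reward, action) pairs
best = max(candidates)
assert best == (2.5, 1)  # action 1 has the highest reward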
Lines changed: 30 additions & 0 deletions

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Integration tests for the LLVM autotuners."""
import pytest
from llvm_autotuning.autotuners import Autotuner

import compiler_gym


@pytest.mark.skip(reason="greedy takes a long time")
def test_autotune():
    with compiler_gym.make("llvm-v0", reward_space="IrInstructionCount") as env:
        env.reset(benchmark="benchmark://cbench-v1/crc32")

        autotuner = Autotuner(
            algorithm="greedy",
            optimization_target="codesize",
            search_time_seconds=3,
        )

        result = autotuner(env)
        print(result)
        assert result.benchmark == "benchmark://cbench-v1/crc32"
        assert result.walltime >= 3
        assert result.commandline == env.commandline()
        assert env.episode_reward
        assert env.benchmark == "benchmark://cbench-v1/crc32"
        assert env.reward_space == "IrInstructionCount"
Lines changed: 64 additions & 0 deletions

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
from time import time
from typing import Tuple

import nevergrad as ng
from llvm_autotuning.optimization_target import OptimizationTarget

from compiler_gym.envs import CompilerEnv


def nevergrad(
    env: CompilerEnv,
    optimization_target: OptimizationTarget,
    search_time_seconds: int,
    seed: int,
    episode_length: int = 100,
    optimizer: str = "DiscreteLenglerOnePlusOne",
    **kwargs
) -> None:
    """Optimize an environment using nevergrad.

    Nevergrad is a gradient-free optimization platform that provides
    implementations of various black box optimization techniques:

    https://facebookresearch.github.io/nevergrad/
    """
    if optimization_target == OptimizationTarget.RUNTIME:

        def calculate_negative_reward(actions: Tuple[int]) -> float:
            env.reset()
            env.step(actions)
            return -env.episode_reward

    else:
        # Only cache the deterministic non-runtime rewards.
        @lru_cache(maxsize=int(1e4))
        def calculate_negative_reward(actions: Tuple[int]) -> float:
            env.reset()
            env.step(actions)
            return -env.episode_reward

    params = ng.p.Choice(
        choices=range(env.action_space.n),
        repetitions=episode_length,
        deterministic=True,
    )
    params.random_state.seed(seed)

    optimizer_class = getattr(ng.optimizers, optimizer)
    optimizer = optimizer_class(parametrization=params, budget=1, num_workers=1)

    end_time = time() + search_time_seconds
    while time() < end_time:
        x = optimizer.ask()
        optimizer.tell(x, calculate_negative_reward(x.value))

    # Get the best solution and replay it.
    recommendation = optimizer.provide_recommendation()
    env.reset()
    env.step(recommendation.value)
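
The ask/tell loop above is the standard nevergrad pattern: ask() proposes a candidate action sequence, tell() reports its loss (negated reward, since nevergrad minimizes). A standalone sketch of the same pattern, independent of CompilerGym, with a toy objective chosen purely for illustration:

# Standalone nevergrad ask/tell sketch; the objective is a toy function.
import nevergrad as ng

params = ng.p.Choice(choices=range(4), repetitions=8, deterministic=True)
opt = ng.optimizers.DiscreteLenglerOnePlusOne(parametrization=params, budget=100)
for _ in range(100):
    x = opt.ask()
    loss = -sum(x.value)  # toy objective: maximize the sum of chosen values
    opt.tell(x, loss)
print(opt.provide_recommendation().value)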
Lines changed: 29 additions & 0 deletions

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Integration tests for the LLVM autotuners."""
from llvm_autotuning.autotuners import Autotuner

import compiler_gym


def test_autotune():
    with compiler_gym.make("llvm-v0", reward_space="IrInstructionCount") as env:
        env.reset(benchmark="benchmark://cbench-v1/crc32")
        env.reward_space = "IrInstructionCount"

        autotuner = Autotuner(
            algorithm="nevergrad",
            optimization_target="codesize",
            search_time_seconds=3,
        )

        result = autotuner(env)
        print(result)
        assert result.benchmark == "benchmark://cbench-v1/crc32"
        assert result.walltime >= 3
        assert result.commandline == env.commandline()
        assert env.episode_reward >= 0
        assert env.benchmark == "benchmark://cbench-v1/crc32"
        assert env.reward_space == "IrInstructionCount"
