|
5 | 5 | characteristic that they are responsible for updating some match data during
|
6 | 6 | their run, such that it contains the current state of the match.
|
7 | 7 | """
|
8 |
| -from __future__ import annotations |
9 |
| -from dataclasses import dataclass |
| 8 | +from dataclasses import dataclass, field as dataclass_field, fields |
10 | 9 | from importlib.metadata import entry_points
|
11 | 10 | import logging
|
12 | 11 | from abc import abstractmethod, ABC
|
13 |
| -from importlib import import_module |
14 |
| -from typing import Type |
| 12 | +from typing import ( |
| 13 | + Any, |
| 14 | + Callable, |
| 15 | + ClassVar, |
| 16 | + Literal, |
| 17 | + Mapping, |
| 18 | + TypeAlias, |
| 19 | + TypeVar, |
| 20 | + dataclass_transform, |
| 21 | + get_origin, |
| 22 | + get_type_hints, |
| 23 | +) |
| 24 | +from algobattle.docker_util import DockerError, Generator, Solver, GeneratorResult, SolverResult |
| 25 | +from algobattle.observer import Subject |
| 26 | +from algobattle.util import Encodable, Role |
15 | 27 |
|
16 |
| -from algobattle.fight_handler import FightHandler |
17 |
| -from algobattle.team import Matchup |
18 |
| -from algobattle.observer import Observer, Subject |
19 |
| -from algobattle.util import CLIParsable |
| 28 | +logger = logging.getLogger("algobattle.battle_wrapper") |
20 | 29 |
|
21 |
| -logger = logging.getLogger('algobattle.battle_wrapper') |
22 | 30 |
|
| 31 | +_Config: TypeAlias = Any |
| 32 | +T = TypeVar("T") |
23 | 33 |
|
24 |
| -class BattleWrapper(ABC): |
| 34 | + |
| 35 | +def argspec(*, default: T, help: str = "", parser: Callable[[str], T] | None = None) -> T: |
| 36 | + """Structure specifying the CLI arg.""" |
| 37 | + metadata = {"help": help, "parser": parser} |
| 38 | + return dataclass_field(default=default, metadata={key: val for key, val in metadata.items() if val is not None}) |
| 39 | + |
| 40 | + |
| 41 | +@dataclass |
| 42 | +class CombinedResults: |
| 43 | + """The result of one execution of the generator and the solver with the generated instance.""" |
| 44 | + |
| 45 | + score: float |
| 46 | + generator: GeneratorResult | DockerError |
| 47 | + solver: SolverResult | DockerError | None |
| 48 | + |
| 49 | + |
| 50 | +class BattleWrapper(Subject, ABC): |
25 | 51 | """Abstract Base class for wrappers that execute a specific kind of battle."""
|
26 | 52 |
|
27 |
| - @dataclass |
28 |
| - class Config(CLIParsable): |
29 |
| - """Object containing the config variables the wrapper will use.""" |
| 53 | + _wrappers: ClassVar[dict[str, type["BattleWrapper"]]] = {} |
| 54 | + |
| 55 | + scoring_team: ClassVar[Role] = "solver" |
30 | 56 |
|
31 |
| - pass |
| 57 | + @dataclass_transform(field_specifiers=(argspec,)) |
| 58 | + class Config: |
| 59 | + """Object containing the config variables the wrapper will use.""" |
32 | 60 |
|
33 |
| - _wrappers: dict[str, Type[BattleWrapper]] = {} |
| 61 | + def __init_subclass__(cls) -> None: |
| 62 | + dataclass(cls) |
| 63 | + super().__init_subclass__() |
| 64 | + |
| 65 | + # providing a dummy default impl that will be overriden, to get better static analysis |
| 66 | + def __init__(self, **kwargs) -> None: |
| 67 | + super().__init__() |
| 68 | + |
| 69 | + @classmethod |
| 70 | + def as_argparse_args(cls) -> list[tuple[str, dict[str, Any]]]: |
| 71 | + """Constructs a list of argument names and `**kwargs` that can be passed to `ArgumentParser.add_argument()`.""" |
| 72 | + arguments: list[tuple[str, dict[str, Any]]] = [] |
| 73 | + resolved_annotations = get_type_hints(cls) |
| 74 | + for field in fields(cls): |
| 75 | + kwargs = { |
| 76 | + "type": field.metadata.get("parser", resolved_annotations[field.name]), |
| 77 | + "help": field.metadata.get("help", "") + f" Default: {field.default}", |
| 78 | + } |
| 79 | + if field.type == bool: |
| 80 | + kwargs["action"] = "store_const" |
| 81 | + kwargs["const"] = not field.default |
| 82 | + elif get_origin(field.type) == Literal: |
| 83 | + kwargs["choices"] = field.type.__args__ |
| 84 | + |
| 85 | + arguments.append((field.name, kwargs)) |
| 86 | + return arguments |
34 | 87 |
|
35 | 88 | @staticmethod
|
36 |
| - def all() -> dict[str, Type[BattleWrapper]]: |
| 89 | + def all() -> dict[str, type["BattleWrapper"]]: |
37 | 90 | """Returns a list of all registered wrappers."""
|
38 | 91 | for entrypoint in entry_points(group="algobattle.wrappers"):
|
39 | 92 | if entrypoint.name not in BattleWrapper._wrappers:
|
40 |
| - wrapper: Type[BattleWrapper] = entrypoint.load() |
41 |
| - BattleWrapper._wrappers[wrapper.name()] = wrapper |
| 93 | + wrapper: type[BattleWrapper] = entrypoint.load() |
| 94 | + BattleWrapper._wrappers[wrapper.name().lower()] = wrapper |
42 | 95 | return BattleWrapper._wrappers
|
43 | 96 |
|
44 |
| - def __init_subclass__(cls) -> None: |
| 97 | + def __init_subclass__(cls, notify_var_changes: bool = False) -> None: |
45 | 98 | if cls.name() not in BattleWrapper._wrappers:
|
46 |
| - BattleWrapper._wrappers[cls.name()] = cls |
47 |
| - return super().__init_subclass__() |
| 99 | + BattleWrapper._wrappers[cls.name().lower()] = cls |
| 100 | + return super().__init_subclass__(notify_var_changes) |
48 | 101 |
|
49 |
| - @staticmethod |
50 |
| - def get_wrapper(wrapper_name: str) -> Type[BattleWrapper]: |
51 |
| - """Try to import a Battle Wrapper from a given name. |
52 |
| -
|
53 |
| - For this to work, a BattleWrapper module with the same name as the argument |
54 |
| - needs to be present in the algobattle/battle_wrappers folder. |
55 |
| -
|
56 |
| - Parameters |
57 |
| - ---------- |
58 |
| - wrapper_name : str |
59 |
| - Name of a battle wrapper module in algobattle/battle_wrappers. |
60 |
| -
|
61 |
| - Returns |
62 |
| - ------- |
63 |
| - BattleWrapper |
64 |
| - A BattleWrapper of the given wrapper_name. |
65 |
| -
|
66 |
| - Raises |
67 |
| - ------ |
68 |
| - ValueError |
69 |
| - If the wrapper does not exist in the battle_wrappers folder. |
70 |
| - """ |
71 |
| - try: |
72 |
| - wrapper_module = import_module("algobattle.battle_wrappers." + wrapper_name) |
73 |
| - return getattr(wrapper_module, wrapper_name.capitalize()) |
74 |
| - except ImportError as e: |
75 |
| - logger.critical(f"Importing a wrapper from the given path failed with the following exception: {e}") |
76 |
| - raise ValueError from e |
| 102 | + @abstractmethod |
| 103 | + def score(self) -> float: |
| 104 | + """The score achieved by the scored team during this battle.""" |
| 105 | + raise NotImplementedError |
77 | 106 |
|
78 |
| - def __init__(self, fight_handler: FightHandler, config: BattleWrapper.Config) -> None: |
79 |
| - super().__init__() |
80 |
| - self.fight_handler = fight_handler |
81 |
| - self.config = config |
| 107 | + @staticmethod |
| 108 | + def format_score(score: float) -> str: |
| 109 | + """Formats a score nicely.""" |
| 110 | + return f"{score:.2f}" |
82 | 111 |
|
83 | 112 | @abstractmethod
|
84 |
| - def run_round(self, matchup: Matchup, observer: Observer | None = None) -> BattleWrapper.Result: |
85 |
| - """Execute a full round of fights between two teams configured in the fight_handler. |
86 |
| -
|
87 |
| - During execution, the concrete BattleWrapper should update the round_data dict |
88 |
| - to which Observers can subscribe in order to react to new intermediate results. |
89 |
| - """ |
| 113 | + def display(self) -> str: |
| 114 | + """Nicely formats the object.""" |
90 | 115 | raise NotImplementedError
|
91 | 116 |
|
92 | 117 | @classmethod
|
93 | 118 | def name(cls) -> str:
|
94 | 119 | """Name of the type of this battle wrapper."""
|
95 | 120 | return cls.__name__
|
96 | 121 |
|
97 |
| - class Result(Subject): |
98 |
| - """Result of a single battle.""" |
99 |
| - |
100 |
| - @property |
101 |
| - @abstractmethod |
102 |
| - def score(self) -> float: |
103 |
| - """The score achieved by the solver of this battle.""" |
104 |
| - raise NotImplementedError |
105 |
| - |
106 |
| - @staticmethod |
107 |
| - @abstractmethod |
108 |
| - def format_score(score: float) -> str: |
109 |
| - """Formats a score nicely.""" |
110 |
| - raise NotImplementedError |
| 122 | + @abstractmethod |
| 123 | + def run_battle(self, generator: Generator, solver: Solver, config: _Config, min_size: int) -> None: |
| 124 | + """Calculates the next instance size that should be fought over.""" |
| 125 | + raise NotImplementedError |
111 | 126 |
|
112 |
| - def __str__(self) -> str: |
113 |
| - return self.format_score(self.score) |
| 127 | + def run_programs( |
| 128 | + self, |
| 129 | + generator: Generator, |
| 130 | + solver: Solver, |
| 131 | + size: int, |
| 132 | + *, |
| 133 | + timeout_generator: float | None = ..., |
| 134 | + space_generator: int | None = ..., |
| 135 | + cpus_generator: int = ..., |
| 136 | + timeout_solver: float | None = ..., |
| 137 | + space_solver: int | None = ..., |
| 138 | + cpus_solver: int = ..., |
| 139 | + generator_battle_input: Mapping[str, Encodable] = {}, |
| 140 | + solver_battle_input: Mapping[str, Encodable] = {}, |
| 141 | + generator_battle_output: Mapping[str, type[Encodable]] = {}, |
| 142 | + solver_battle_output: Mapping[str, type[Encodable]] = {}, |
| 143 | + ) -> CombinedResults: |
| 144 | + """Execute a single fight of a battle, running the generator and solver and handling any errors gracefully.""" |
| 145 | + self.notify() |
| 146 | + try: |
| 147 | + gen_result = generator.run( |
| 148 | + size=size, |
| 149 | + timeout=timeout_generator, |
| 150 | + space=space_generator, |
| 151 | + cpus=cpus_generator, |
| 152 | + battle_input=generator_battle_input, |
| 153 | + battle_output=generator_battle_output, |
| 154 | + ) |
| 155 | + except DockerError as e: |
| 156 | + return CombinedResults(score=1, generator=e, solver=None) |
114 | 157 |
|
115 |
| - @abstractmethod |
116 |
| - def display(self) -> str: |
117 |
| - """Nicely formats the object.""" |
118 |
| - raise NotImplementedError |
| 158 | + try: |
| 159 | + sol_result = solver.run( |
| 160 | + gen_result.problem, |
| 161 | + size=size, |
| 162 | + timeout=timeout_solver, |
| 163 | + space=space_solver, |
| 164 | + cpus=cpus_solver, |
| 165 | + battle_input=solver_battle_input, |
| 166 | + battle_output=solver_battle_output, |
| 167 | + ) |
| 168 | + except DockerError as e: |
| 169 | + return CombinedResults(score=0, generator=gen_result, solver=e) |
| 170 | + |
| 171 | + score = gen_result.problem.calculate_score( |
| 172 | + solution=sol_result.solution, generator_solution=gen_result.solution, size=size |
| 173 | + ) |
| 174 | + score = max(0, min(1, float(score))) |
| 175 | + logger.info(f"The solver achieved a score of {score}.") |
| 176 | + return CombinedResults(score, gen_result, sol_result) |
0 commit comments