-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
56 lines (44 loc) · 1.62 KB
/
evaluate.py
File metadata and controls
56 lines (44 loc) · 1.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import itertools
import jax
import jax.numpy as jnp
import jax.random as random
import airports
import game
import learning
import play_game
def main():
    """Evaluate the trained AI's win rate across airports and ability pairs.

    Loads the saved model, then for every airport and every 2-combination of
    the four crew abilities plays 1024 games (in batches of ``batch_size``)
    and prints the win percentage for that airport/ability combination.
    """
    train_state = play_game.load_ai('./data/aiviator/')
    batch_size = 256
    for airport, airport_name in [
            (airports.KEF, 'KEF'), (airports.KUL, 'KUL'), (airports.PBH, 'PBH'),
            (airports.PBH_RED, 'PBH_RED'), (airports.GIG, 'GIG'), (airports.TGU, 'TGU'),
            (airports.OSL, 'OSL')]:
        print(airport_name)
        # JIT-compiled, batched initial-state constructor for this airport.
        make_state = jax.jit(jax.vmap(airport))
        for selected_abilities in itertools.combinations(
                ['working_together', 'mastery', 'anticipation', 'control'], 2):
            # Per-ability boolean flag, replicated across the batch:
            # True iff the ability is in the currently selected pair.
            abilities = {
                ability_name: jnp.array([ability_name in selected_abilities] * batch_size)
                for ability_name in ['working_together', 'mastery', 'anticipation', 'control']
            }
            wins, losses = 0, 0
            # 1024 games total, played batch_size at a time.
            for seed in range(1024 // batch_size):
                key = random.PRNGKey(seed)
                # Split before use so the per-game init keys match the
                # original evaluation's RNG stream exactly.
                key, init_key = random.split(key)
                states = make_state(key=random.split(init_key, batch_size), **abilities)
                scores = jnp.zeros(batch_size)
                # Only the final turn's result is inspected, so keep just the
                # last step output instead of accumulating a full history list.
                last_result = None
                for _ in range(learning.MAX_TURNS):
                    states, scores, last_result = learning.trajectory_step(
                        train_state, states, scores, None)
                # last_result[0].result appears to hold one WIN/LOSS outcome
                # per game in the batch — confirm against learning.trajectory_step.
                for r in last_result[0].result:
                    if r == game.WIN:
                        wins += 1
                    else:
                        losses += 1
            total = wins + losses
            print(f"  {selected_abilities}: {wins}/{total} = {100 * wins / total:0.1f}%")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()