1
1
from __future__ import annotations
2
2
3
3
import collections
4
+ from typing import Callable
4
5
5
6
import numpy as np
6
7
from skopt import Optimizer
@@ -25,8 +26,8 @@ class SKOptLearner(Optimizer, BaseLearner):
25
26
Arguments to pass to ``skopt.Optimizer``.
26
27
"""
27
28
28
- def __init__ (self , function , ** kwargs ):
29
- self .function = function
29
+ def __init__ (self , function : Callable , ** kwargs ) -> None :
30
+ self .function = function # type: ignore
30
31
self .pending_points = set ()
31
32
self .data = collections .OrderedDict ()
32
33
self ._kwargs = kwargs
@@ -36,7 +37,7 @@ def new(self) -> SKOptLearner:
36
37
"""Return a new `~adaptive.SKOptLearner` without the data."""
37
38
return SKOptLearner (self .function , ** self ._kwargs )
38
39
39
- def tell (self , x , y , fit = True ):
40
+ def tell (self , x : float | list [ float ] , y : float , fit : bool = True ) -> None :
40
41
if isinstance (x , collections .abc .Iterable ):
41
42
self .pending_points .discard (tuple (x ))
42
43
self .data [tuple (x )] = y
@@ -55,7 +56,7 @@ def remove_unfinished(self):
55
56
pass
56
57
57
58
@cache_latest
58
- def loss (self , real = True ):
59
+ def loss (self , real : bool = True ) -> float :
59
60
if not self .models :
60
61
return np .inf
61
62
else :
@@ -65,7 +66,12 @@ def loss(self, real=True):
65
66
# estimator of loss, but it is the cheapest.
66
67
return 1 - model .score (self .Xi , self .yi )
67
68
68
- def ask (self , n , tell_pending = True ):
69
+ def ask (
70
+ self , n : int , tell_pending : bool = True
71
+ ) -> (
72
+ tuple [list [float ], list [float ]]
73
+ | tuple [list [list [float ]], list [float ]] # XXX: this indicates a bug!
74
+ ):
69
75
if not tell_pending :
70
76
raise NotImplementedError (
71
77
"Asking points is an irreversible "
@@ -79,7 +85,7 @@ def ask(self, n, tell_pending=True):
79
85
return [p [0 ] for p in points ], [self .loss () / n ] * n
80
86
81
87
@property
82
- def npoints (self ):
88
+ def npoints (self ) -> int :
83
89
"""Number of evaluated points."""
84
90
return len (self .Xi )
85
91
0 commit comments