Skip to content

Commit 7de89c8

Browse files
committed
Add type-hints to adaptive/learner/skopt_learner.py
1 parent 1b7e84d commit 7de89c8

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

adaptive/learner/skopt_learner.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import collections
4+
from typing import Callable
45

56
import numpy as np
67
from skopt import Optimizer
@@ -25,8 +26,8 @@ class SKOptLearner(Optimizer, BaseLearner):
2526
Arguments to pass to ``skopt.Optimizer``.
2627
"""
2728

28-
def __init__(self, function, **kwargs):
29-
self.function = function
29+
def __init__(self, function: Callable, **kwargs) -> None:
30+
self.function = function # type: ignore
3031
self.pending_points = set()
3132
self.data = collections.OrderedDict()
3233
self._kwargs = kwargs
@@ -36,7 +37,7 @@ def new(self) -> SKOptLearner:
3637
"""Return a new `~adaptive.SKOptLearner` without the data."""
3738
return SKOptLearner(self.function, **self._kwargs)
3839

39-
def tell(self, x, y, fit=True):
40+
def tell(self, x: float | list[float], y: float, fit: bool = True) -> None:
4041
if isinstance(x, collections.abc.Iterable):
4142
self.pending_points.discard(tuple(x))
4243
self.data[tuple(x)] = y
@@ -55,7 +56,7 @@ def remove_unfinished(self):
5556
pass
5657

5758
@cache_latest
58-
def loss(self, real=True):
59+
def loss(self, real: bool = True) -> float:
5960
if not self.models:
6061
return np.inf
6162
else:
@@ -65,7 +66,12 @@ def loss(self, real=True):
6566
# estimator of loss, but it is the cheapest.
6667
return 1 - model.score(self.Xi, self.yi)
6768

68-
def ask(self, n, tell_pending=True):
69+
def ask(
70+
self, n: int, tell_pending: bool = True
71+
) -> (
72+
tuple[list[float], list[float]]
73+
| tuple[list[list[float]], list[float]] # XXX: this indicates a bug!
74+
):
6975
if not tell_pending:
7076
raise NotImplementedError(
7177
"Asking points is an irreversible "
@@ -79,7 +85,7 @@ def ask(self, n, tell_pending=True):
7985
return [p[0] for p in points], [self.loss() / n] * n
8086

8187
@property
82-
def npoints(self):
88+
def npoints(self) -> int:
8389
"""Number of evaluated points."""
8490
return len(self.Xi)
8591

0 commit comments

Comments
 (0)