Skip to content

Commit 999f3c3

Browse files
authored
[ADD] Stricter checks mypy (#240)
* Fix mypy errors
* Fix error in dask import
* Fix flake
* Attempt to fix mypy errors
* Last error
* Fix flake after rebase
* Address comments from shuhei
* Test updated
* Address comments from Francisco, make BaseDatasetPropertiesType
* fix torch tensor import
* Fix tests from dataset properties and address comments from shuhei
* Fix tests
* Fix mypy after rebase
* fix flake
* debug
* increase patience for early stopping
* change seed of pipeline
* change precision
* fix doc discrepancy
1 parent 0771dce commit 999f3c3

File tree

101 files changed

+622
-459
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+622
-459
lines changed

.pre-commit-config.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@ repos:
33
rev: v0.761
44
hooks:
55
- id: mypy
6-
args: [--show-error-codes]
7-
name: mypy AutoPyTorch
6+
args: [--show-error-codes,
7+
--warn-redundant-casts,
8+
--warn-return-any,
9+
--warn-unreachable,
10+
]
811
files: autoPyTorch/.*
12+
exclude: autoPyTorch/ensemble/
913
- repo: https://gitlab.com/pycqa/flake8
1014
rev: 3.8.3
1115
hooks:
1216
- id: flake8
13-
name: flake8 AutoPyTorch
14-
files: autoPyTorch/.*
1517
additional_dependencies:
1618
- flake8-print==3.1.4
1719
- flake8-import-order
20+
name: flake8 autoPyTorch
21+
files: autoPyTorch/.*
1822
- id: flake8
19-
name: flake8 tests
20-
files: test/.*
2123
additional_dependencies:
2224
- flake8-print==3.1.4
2325
- flake8-import-order
26+
name: flake8 test
27+
files: test/.*

autoPyTorch/api/base_task.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
import unittest.mock
1313
import warnings
1414
from abc import abstractmethod
15-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
15+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1616

1717
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
1818

1919
import dask
20+
import dask.distributed
2021

2122
import joblib
2223

@@ -38,7 +39,6 @@
3839
from autoPyTorch.datasets.base_dataset import BaseDataset
3940
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
4041
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
41-
from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection
4242
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
4343
from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
4444
from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
@@ -198,7 +198,7 @@ def __init__(
198198
# examples. Nevertheless, multi-process runs
199199
# have spawn as requirement to reduce the
200200
# possibility of a deadlock
201-
self._dask_client = None
201+
self._dask_client: Optional[dask.distributed.Client] = None
202202
self._multiprocessing_context = 'forkserver'
203203
if self.n_jobs == 1:
204204
self._multiprocessing_context = 'fork'
@@ -711,7 +711,8 @@ def _search(
711711
precision: int = 32,
712712
disable_file_output: List = [],
713713
load_models: bool = True,
714-
portfolio_selection: Optional[str] = None
714+
portfolio_selection: Optional[str] = None,
715+
dask_client: Optional[dask.distributed.Client] = None
715716
) -> 'BaseTask':
716717
"""
717718
Search for the best pipeline configuration for the given dataset.
@@ -857,10 +858,11 @@ def _search(
857858
# If no dask client was provided, we create one, so that we can
858859
# start an ensemble process in parallel to the SMBO optimization
859860
if (
860-
self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
861+
dask_client is None and (self.ensemble_size > 0 or self.n_jobs > 1)
861862
):
862863
self._create_dask_client()
863864
else:
865+
self._dask_client = dask_client
864866
self._is_dask_client_internally_created = False
865867

866868
# Handle time resource allocation
@@ -1206,7 +1208,6 @@ def predict(
12061208

12071209
# Mypy assert
12081210
assert self.ensemble_ is not None, "Load models should error out if no ensemble"
1209-
self.ensemble_ = cast(Union[SingleBest, EnsembleSelection], self.ensemble_)
12101211

12111212
if isinstance(self.resampling_strategy, HoldoutValTypes):
12121213
models = self.models_
@@ -1315,15 +1316,17 @@ def get_models_with_weights(self) -> List:
13151316
self._load_models()
13161317

13171318
assert self.ensemble_ is not None
1318-
return self.ensemble_.get_models_with_weights(self.models_)
1319+
models_with_weights: List[Tuple[float, BasePipeline]] = self.ensemble_.get_models_with_weights(self.models_)
1320+
return models_with_weights
13191321

13201322
def show_models(self) -> str:
13211323
df = []
13221324
for weight, model in self.get_models_with_weights():
13231325
representation = model.get_pipeline_representation()
13241326
representation.update({'Weight': weight})
13251327
df.append(representation)
1326-
return pd.DataFrame(df).to_markdown()
1328+
models_markdown: str = pd.DataFrame(df).to_markdown()
1329+
return models_markdown
13271330

13281331
def _print_debug_info_to_log(self) -> None:
13291332
"""

autoPyTorch/data/base_target_validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def fit(
9595
np.shape(y_test)
9696
))
9797
if isinstance(y_train, pd.DataFrame):
98-
y_train = typing.cast(pd.DataFrame, y_train)
9998
y_test = typing.cast(pd.DataFrame, y_test)
10099
if y_train.columns.tolist() != y_test.columns.tolist():
101100
raise ValueError(

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def transform(
145145
X = self.numpy_array_to_pandas(X)
146146

147147
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
148-
X = typing.cast(pd.DataFrame, X)
149148
if np.any(pd.isnull(X)):
150149
for column in X.columns:
151150
if X[column].isna().all():

autoPyTorch/data/tabular_target_validator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,9 @@ def _check_data(
194194
A set of features whose dimensionality and data type is going to be checked
195195
"""
196196

197-
if not isinstance(
198-
y, (np.ndarray, pd.DataFrame, list, pd.Series)) and not scipy.sparse.issparse(y):
197+
if not isinstance(y, (np.ndarray, pd.DataFrame,
198+
typing.List, pd.Series)) \
199+
and not scipy.sparse.issparse(y): # type: ignore[misc]
199200
raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
200201
" pd.Series, sparse data and Python Lists as targets, yet, "
201202
"the provided input is of type {}".format(

autoPyTorch/datasets/base_dataset.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from autoPyTorch.utils.common import FitRequirement
2727

2828
BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
29+
BaseDatasetPropertiesType = Union[int, float, str, List, bool]
2930

3031

3132
def check_valid_data(data: Any) -> None:
@@ -125,7 +126,6 @@ def __init__(
125126
self.task_type: Optional[str] = None
126127
self.issparse: bool = issparse(self.train_tensors[0])
127128
self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
128-
129129
if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
130130
self.output_type: str = type_of_target(self.train_tensors[1])
131131

@@ -205,7 +205,7 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
205205
return X, Y
206206

207207
def __len__(self) -> int:
208-
return self.train_tensors[0].shape[0]
208+
return int(self.train_tensors[0].shape[0])
209209

210210
def _get_indices(self) -> np.ndarray:
211211
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
@@ -349,7 +349,9 @@ def replace_data(self, X_train: BaseDatasetInputType,
349349
self.test_tensors = (X_test, self.test_tensors[1])
350350
return self
351351

352-
def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> Dict[str, Any]:
352+
def get_dataset_properties(
353+
self, dataset_requirements: List[FitRequirement]
354+
) -> Dict[str, BaseDatasetPropertiesType]:
353355
"""
354356
Gets the dataset properties required in the fit dictionary.
355357
This depends on the components that are active in the
@@ -364,7 +366,7 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
364366
<https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/utils/pipeline.py#L25>`
365367
366368
Returns:
367-
dataset_properties (Dict[str, Any]):
369+
dataset_properties (Dict[str, BaseDatasetPropertiesType]):
368370
Dict of the dataset properties.
369371
"""
370372
dataset_properties = dict()
@@ -376,11 +378,11 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
376378
dataset_properties.update(self.get_required_dataset_info())
377379
return dataset_properties
378380

379-
def get_required_dataset_info(self) -> Dict[str, Any]:
381+
def get_required_dataset_info(self) -> Dict[str, BaseDatasetPropertiesType]:
380382
"""
381383
Returns a dictionary containing required dataset
382384
properties to instantiate a pipeline.
383385
"""
384-
info = {'output_type': self.output_type,
385-
'issparse': self.issparse}
386+
info: Dict[str, BaseDatasetPropertiesType] = {'output_type': self.output_type,
387+
'issparse': self.issparse}
386388
return info

autoPyTorch/datasets/tabular_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
TASK_TYPES_TO_STRING,
1818
)
1919
from autoPyTorch.data.base_validator import BaseInputValidator
20-
from autoPyTorch.datasets.base_dataset import BaseDataset
20+
from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
2121
from autoPyTorch.datasets.resampling_strategy import (
2222
CrossValTypes,
2323
HoldoutValTypes,
@@ -98,7 +98,7 @@ def __init__(self,
9898
if STRING_TO_TASK_TYPES[self.task_type] in CLASSIFICATION_TASKS:
9999
self.num_classes: int = len(np.unique(self.train_tensors[1]))
100100

101-
def get_required_dataset_info(self) -> Dict[str, Any]:
101+
def get_required_dataset_info(self) -> Dict[str, BaseDatasetPropertiesType]:
102102
"""
103103
Returns a dictionary containing required dataset
104104
properties to instantiate a pipeline.
@@ -120,6 +120,7 @@ def get_required_dataset_info(self) -> Dict[str, Any]:
120120
<https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/constants.py>`
121121
"""
122122
info = super().get_required_dataset_info()
123+
assert self.task_type is not None, "Expected value for task type but got None"
123124
info.update({
124125
'numerical_columns': self.numerical_columns,
125126
'categorical_columns': self.categorical_columns,

autoPyTorch/ensemble/ensemble_builder.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,21 @@ def __init__(
9191
Both wrt to validation predictions
9292
If performance_range_threshold > 0, might return less models
9393
max_models_on_disc: Union[float, int]
94-
Defines the maximum number of models that are kept in the disc.
95-
If int, it must be greater or equal than 1, and dictates the max number of
96-
models to keep.
97-
If float, it will be interpreted as the max megabytes allowed of disc space. That
98-
is, if the number of ensemble candidates require more disc space than this float
99-
value, the worst models will be deleted to keep within this budget.
100-
Models and predictions of the worst-performing models will be deleted then.
101-
If None, the feature is disabled.
102-
It defines an upper bound on the models that can be used in the ensemble.
94+
Defines the maximum number of models that are kept in the disc.
95+
If int, it must be greater than or equal to 1, and dictates the max number of
96+
models to keep.
97+
If float, it will be interpreted as the max megabytes allowed of disc space. That
98+
is, if the number of ensemble candidates require more disc space than this float
99+
value, the worst models will be deleted to keep within this budget.
100+
Models and predictions of the worst-performing models will be deleted then.
101+
If None, the feature is disabled.
102+
It defines an upper bound on the models that can be used in the ensemble.
103103
seed: int
104104
random seed
105105
max_iterations: int
106106
maximal number of iterations to run this script
107107
(default None --> deactivated)
108-
precision: [16,32,64,128]
108+
precision (int): [16,32,64,128]
109109
precision of floats to read the predictions
110110
memory_limit: Optional[int]
111111
memory limit in mb. If ``None``, no memory limit is enforced.
@@ -324,7 +324,7 @@ def fit_and_return_ensemble(
324324
It defines an upper bound on the models that can be used in the ensemble.
325325
seed: int
326326
random seed
327-
precision: [16,32,64,128]
327+
precision (int): [16,32,64,128]
328328
precision of floats to read the predictions
329329
memory_limit: Optional[int]
330330
memory limit in mb. If ``None``, no memory limit is enforced.
@@ -1506,15 +1506,7 @@ def _delete_excess_models(self, selected_keys: List[str]) -> None:
15061506
)
15071507

15081508
def _read_np_fn(self, path: str) -> np.ndarray:
1509-
1510-
# Support for string precision
1511-
if isinstance(self.precision, str):
1512-
precision = int(self.precision)
1513-
self.logger.warning("Interpreted str-precision as {}".format(
1514-
precision
1515-
))
1516-
else:
1517-
precision = self.precision
1509+
precision = self.precision
15181510

15191511
if path.endswith("gz"):
15201512
fp = gzip.open(path, 'rb')

autoPyTorch/ensemble/ensemble_selection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ def _fit(
149149
if len(predictions) == 1:
150150
break
151151

152-
self.indices_ = order
153-
self.trajectory_ = trajectory
154-
self.train_loss_ = trajectory[-1]
152+
self.indices_: List[int] = order
153+
self.trajectory_: List[float] = trajectory
154+
self.train_loss_: float = trajectory[-1]
155155

156156
def _calculate_weights(self) -> None:
157157
"""

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
STRING_TO_TASK_TYPES,
3232
TABULAR_TASKS,
3333
)
34-
from autoPyTorch.datasets.base_dataset import BaseDataset
34+
from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
3535
from autoPyTorch.evaluation.utils import (
3636
VotingRegressorWrapper,
3737
convert_multioutput_multiclass_to_multilabel
@@ -63,7 +63,7 @@ class MyTraditionalTabularClassificationPipeline(BaseEstimator):
6363
learning model, and is the final object that is stored for inference.
6464
6565
Attributes:
66-
dataset_properties (Dict[str, Any]):
66+
dataset_properties (Dict[str, BaseDatasetPropertiesType]):
6767
A dictionary containing dataset specific information
6868
random_state (Optional[np.random.RandomState]):
6969
Object that contains a seed and allows for reproducible results
@@ -73,8 +73,8 @@ class MyTraditionalTabularClassificationPipeline(BaseEstimator):
7373
"""
7474

7575
def __init__(self, config: str,
76-
dataset_properties: Dict[str, Any],
77-
random_state: Optional[np.random.RandomState] = None,
76+
dataset_properties: Dict[str, BaseDatasetPropertiesType],
77+
random_state: Optional[Union[int, np.random.RandomState]] = None,
7878
init_params: Optional[Dict] = None):
7979
self.config = config
8080
self.dataset_properties = dataset_properties
@@ -197,8 +197,6 @@ class DummyClassificationPipeline(DummyClassifier):
197197
worst performing model. In case of failure, at least this model will be fitted.
198198
199199
Attributes:
200-
dataset_properties (Dict[str, Any]):
201-
A dictionary containing dataset specific information
202200
random_state (Optional[Union[int, np.random.RandomState]]):
203201
Object that contains a seed and allows for reproducible results
204202
init_params (Optional[Dict]):
@@ -262,8 +260,6 @@ class DummyRegressionPipeline(DummyRegressor):
262260
worst performing model. In case of failure, at least this model will be fitted.
263261
264262
Attributes:
265-
dataset_properties (Dict[str, Any]):
266-
A dictionary containing dataset specific information
267263
random_state (Optional[Union[int, np.random.RandomState]]):
268264
Object that contains a seed and allows for reproducible results
269265
init_params (Optional[Dict]):

autoPyTorch/evaluation/tae.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def fit_predict_try_except_decorator(
3636
ta: typing.Callable,
3737
queue: multiprocessing.Queue, cost_for_crash: float, **kwargs: typing.Any) -> None:
3838
try:
39-
return ta(queue=queue, **kwargs)
39+
ta(queue=queue, **kwargs)
4040
except Exception as e:
4141
if isinstance(e, (MemoryError, pynisher.TimeoutException)):
4242
# Re-raise the memory error to let the pynisher handle that correctly
@@ -147,13 +147,15 @@ def __init__(
147147
self.exclude = exclude
148148
self.disable_file_output = disable_file_output
149149
self.init_params = init_params
150+
151+
self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
152+
150153
self.pipeline_config: typing.Dict[str, typing.Union[int, str, float]] = dict()
151154
if pipeline_config is None:
152155
pipeline_config = replace_string_bool_to_bool(json.load(open(
153156
os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
154157
self.pipeline_config.update(pipeline_config)
155158

156-
self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
157159
self.logger_port = logger_port
158160
if self.logger_port is None:
159161
self.logger: typing.Union[logging.Logger, PicklableClientLogger] = logging.getLogger("TAE")
@@ -237,7 +239,8 @@ def run_wrapper(
237239
run_info = run_info._replace(cutoff=int(np.ceil(run_info.cutoff)))
238240

239241
self.logger.info("Starting to evaluate configuration %s" % run_info.config.config_id)
240-
return super().run_wrapper(run_info=run_info)
242+
run_info, run_value = super().run_wrapper(run_info=run_info)
243+
return run_info, run_value
241244

242245
def run(
243246
self,

autoPyTorch/optimizer/smbo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,11 @@ def __init__(self,
213213

214214
self.search_space_updates = search_space_updates
215215

216-
dataset_name_ = "" if dataset_name is None else dataset_name
217216
if logger_port is None:
218217
self.logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
219218
else:
220219
self.logger_port = logger_port
221-
logger_name = '%s(%d):%s' % (self.__class__.__name__, self.seed, ":" + dataset_name_)
220+
logger_name = '%s(%d):%s' % (self.__class__.__name__, self.seed, ":" + self.dataset_name)
222221
self.logger = get_named_client_logger(name=logger_name,
223222
port=self.logger_port)
224223
self.logger.info("initialised {}".format(self.__class__.__name__))

0 commit comments

Comments
 (0)