Expected Behavior
Use the relevant algorithms to forecast the values of the last 5 years from the first 11 years.
Current Behavior
The error "ValueError: 'a' cannot be empty unless no samples are taken" occurs during the evaluation phase.
Possible Solution
This seems to be related to the "cheapest multi-fidelity of 1/9" setting mentioned in the paper, which might cause failures when evaluating extremely short time series, such as the longley dataset with forecasting_horizon=5.
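As a quick check of where that fraction appears in this run: the log in the Error message section reports 'budget': 5.555555555555555 with 'epochs': 50, which matches 50 × 1/9. The sketch below only reproduces that arithmetic. The helper name `cheapest_budget` is made up for illustration and is not part of Auto-PyTorch, and the calculation does not by itself show that the budget causes the error.

```python
def cheapest_budget(max_budget: float, fraction: float = 1 / 9) -> float:
    """Lowest fidelity as a fixed fraction of the full budget (illustrative helper)."""
    return max_budget * fraction


# 50 is the 'epochs' value from pipeline_config in the log.
print(cheapest_budget(50))
```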
Your Code
import os
import tempfile as tmp
import warnings
import copy

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

from sktime.datasets import load_longley

targets, features = load_longley()

""" this change from 3 to 5 leads to the ValueError """
forecasting_horizon = 5

# Dataset optimized by APT-TS can be a list of np.ndarray/ pd.DataFrame where each series represents an element in the
# list, or a single pd.DataFrame that records the series
# index information: to which series the timestep belongs? This id can be stored as the DataFrame's index or a separate
# column
# Within each series, we take the last forecasting_horizon as test targets. The items before that as training targets
# Normally the value to be forecasted should follow the training sets
y_train = [targets[: -forecasting_horizon]]
y_test = [targets[-forecasting_horizon:]]

# same for features. For uni-variant models, X_train, X_test can be omitted and set as None
X_train = [features[: -forecasting_horizon]]
# Here x_test indicates the 'known future features': they are the features known previously, features that are unknown
# could be replaced with NAN or zeros (which will not be used by our networks). If no feature is known beforehand,
# we could also omit X_test
known_future_features = list(features.columns)
X_test = [features[-forecasting_horizon:]]

start_times = [targets.index.to_timestamp()[0]]
freq = '1Y'

from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask

############################################################################
# Build and fit a forecaster
# ==========================
api = TimeSeriesForecastingTask()

############################################################################
# Search for an ensemble of machine learning algorithms
# =====================================================
api.search(
    X_train=X_train,
    y_train=copy.deepcopy(y_train),
    X_test=X_test,
    optimize_metric='mean_MASE_forecasting',
    n_prediction_steps=forecasting_horizon,
    memory_limit=None,  # Currently, forecasting models use much more memories
    freq=freq,
    start_times=start_times,
    func_eval_time_limit_secs=50,
    total_walltime_limit=60,
    min_num_test_instances=1000,  # proxy validation sets. This only works for the tasks with more than 1000 series
    known_future_features=known_future_features,
)

from autoPyTorch.datasets.time_series_dataset import TimeSeriesSequence

test_sets = []

# We could construct test sets from scratch
for feature, future_feature, target, start_time in zip(X_train, X_test, y_train, start_times):
    test_sets.append(
        TimeSeriesSequence(X=feature.values,
                           Y=target.values,
                           X_test=future_feature.values,
                           start_time=start_time,
                           is_test_set=True,
                           # additional information required to construct a new time series sequence
                           **api.dataset.sequences_builder_kwargs
                           )
    )
# Alternatively, if we only want to forecast the value after the X_train, we could directly ask datamanager to
# generate a test set:
# test_sets2 = api.dataset.generate_test_seqs()

pred = api.predict(test_sets)
Error message
[INFO] [2022-08-07 11:05:48,498:Client-TAE] Starting to evaluate configuration 1
[DEBUG] [2022-08-07 11:05:48,504:Client-TAE] Search space updates for 2: <autoPyTorch.utils.hyperparameter_search_space_update.HyperparameterSearchSpaceUpdates object at 0x7f384a13eb20>
[DEBUG] [2022-08-07 11:05:48,504:Client-pynisher] Restricting your function to 49 seconds wall time.
[DEBUG] [2022-08-07 11:05:48,504:Client-pynisher] Allowing a grace period of 0 seconds.
[DEBUG] [2022-08-07 11:05:48,505:Client-pynisher] Function called with argument: (), {'queue': <multiprocessing.queues.Queue object at 0x7f390df9fd60>, 'config': Configuration(values={
'data_loader:backcast': False,
'data_loader:batch_size': 32,
'data_loader:num_batches_per_epoch': 50,
'data_loader:sample_strategy': 'SeqUniform',
'data_loader:transform_time_features': False,
'data_loader:window_size': 2,
'feature_encoding:__choice__': 'NoEncoder',
'loss:DistributionLoss:aggregation': 'median',
'loss:DistributionLoss:dist_cls': 'studentT',
'loss:DistributionLoss:forecast_strategy': 'sample',
'loss:DistributionLoss:num_samples': 100,
'loss:__choice__': 'DistributionLoss',
'lr_scheduler:ReduceLROnPlateau:factor': 0.5,
'lr_scheduler:ReduceLROnPlateau:mode': 'max',
'lr_scheduler:ReduceLROnPlateau:patience': 10,
'lr_scheduler:__choice__': 'ReduceLROnPlateau',
'network_backbone:__choice__': 'flat_encoder',
'network_backbone:flat_encoder:MLPDecoder:has_local_layer': True,
'network_backbone:flat_encoder:MLPDecoder:num_layers': 0,
'network_backbone:flat_encoder:MLPDecoder:units_local_layer': 40,
'network_backbone:flat_encoder:MLPEncoder:activation': 'relu',
'network_backbone:flat_encoder:MLPEncoder:normalization': 'NoNorm',
'network_backbone:flat_encoder:MLPEncoder:num_groups': 1,
'network_backbone:flat_encoder:MLPEncoder:num_units_1': 40,
'network_backbone:flat_encoder:MLPEncoder:use_dropout': False,
'network_backbone:flat_encoder:__choice__': 'MLPEncoder',
'network_embedding:__choice__': 'NoEmbedding',
'network_init:XavierInit:bias_strategy': 'Normal',
'network_init:__choice__': 'XavierInit',
'optimizer:AdamOptimizer:beta1': 0.9,
'optimizer:AdamOptimizer:beta2': 0.999,
'optimizer:AdamOptimizer:lr': 0.001,
'optimizer:AdamOptimizer:weight_decay': 1e-08,
'optimizer:__choice__': 'AdamOptimizer',
'scaler:scaling_mode': 'standard',
'target_scaler:scaling_mode': 'mean_abs',
'trainer:__choice__': 'ForecastingStandardTrainer',
})
, 'backend': <autoPyTorch.automl_common.common.utils.backend.Backend object at 0x7f390dff1f70>, 'metric': mean_MASE_forecasting, 'seed': 1, 'num_run': 2, 'output_y_hat_optimization': True, 'include': {}, 'exclude': {}, 'disable_file_output': [], 'instance': '{"task_id": "d25c0a6a-15fd-11ed-8fee-1dfac5c47e62"}', 'init_params': {'instance': '{"task_id": "d25c0a6a-15fd-11ed-8fee-1dfac5c47e62"}'}, 'budget': 5.555555555555555, 'budget_type': 'epochs', 'pipeline_config': {'device': 'cuda', 'budget_type': 'epochs', 'epochs': 50, 'runtime': 3600, 'torch_num_threads': 1, 'early_stopping': 20, 'use_tensorboard_logger': False, 'metrics_during_training': True, 'optimize_metric': 'mean_MASE_forecasting'}, 'logger_port': 46370, 'all_supported_metrics': True, 'search_space_updates': <autoPyTorch.utils.hyperparameter_search_space_update.HyperparameterSearchSpaceUpdates object at 0x7f384a13eb20>}
[DEBUG] [2022-08-07 11:05:51,128:Client-pynisher] Redirecting output of the function to files. Access them via the stdout and stderr attributes of the wrapped function.
[DEBUG] [2022-08-07 11:05:51,129:Client-pynisher] call function
[DEBUG] [2022-08-07 11:05:51,503:Client-TimeSeriesForecastingTrainEvaluator(1)] Fit dictionary in Abstract evaluator: dataset_properties: {'task_type': 'time_series_forecasting', 'categorical_features': [], 'feature_shapes': {'GNPDEFL': 1, 'GNP': 1, 'UNEMP': 1, 'ARMED': 1, 'POP': 1}, 'input_shape': (6, 5), 'sequence_lengths_train': array([6]), 'static_features': (), 'time_feature_names': ('time_feature_Constant',), 'categorical_columns': [], 'numerical_columns': [0, 1, 2, 3, 4], 'numerical_features': [0, 1, 2, 3, 4], 'output_shape': [5, 1], 'n_prediction_steps': 5, 'freq': '1Y', 'categories': [], 'feature_names': ('GNPDEFL', 'GNP', 'UNEMP', 'ARMED', 'POP'), 'is_small_preprocess': True, 'known_future_features': (), 'output_type': 'continuous', 'issparse': False, 'sp': 1, 'time_feature_transform': [Constant()], 'uni_variant': False, 'static_features_shape': 0, 'future_feature_shapes': (5, 0), 'targets_have_missing_values': False, 'encoder_can_be_auto_regressive': False, 'features_have_missing_values': False}
additional_metrics: ['mean_MASE_forecasting', 'mean_MASE_forecasting', 'median_MASE_forecasting', 'mean_MAE_forecasting', 'median_MAE_forecasting', 'mean_MAPE_forecasting', 'median_MAPE_forecasting', 'mean_MSE_forecasting', 'median_MSE_forecasting']
X_train: GNPDEFL GNP UNEMP ARMED POP
0 83.0 234289.0 2356.0 1590.0 107608.0
0 88.5 259426.0 2325.0 1456.0 108632.0
0 88.2 258054.0 3682.0 1616.0 109773.0
0 89.5 284599.0 3351.0 1650.0 110929.0
0 96.2 328975.0 2099.0 3099.0 112075.0
0 98.1 346999.0 1932.0 3594.0 113270.0
0 99.0 365385.0 1870.0 3547.0 115094.0
0 100.0 363112.0 3578.0 3350.0 116219.0
0 101.2 397469.0 2904.0 3048.0 117388.0
0 104.6 419180.0 2822.0 2857.0 118734.0
0 108.4 442769.0 2936.0 2798.0 120445.0
y_train: 0
0 60323.0
0 61122.0
0 60171.0
0 61187.0
0 63221.0
0 63639.0
0 64989.0
0 63761.0
0 66019.0
0 67857.0
0 68169.0
X_test: GNPDEFL GNP UNEMP ARMED POP
0 110.8 444546.0 4681.0 2637.0 121950.0
0 112.6 482704.0 3813.0 2552.0 123366.0
0 114.2 502601.0 3931.0 2514.0 125368.0
0 115.7 518173.0 4806.0 2572.0 127852.0
0 116.9 554894.0 4007.0 2827.0 130081.0
y_test: [[66513.00000001]
[68655.00000001]
[69564.00000001]
[69331.00000001]
[70551.00000001]]
backend: <autoPyTorch.automl_common.common.utils.backend.Backend object at 0x7f76c50b2bb0>
logger_port: 46370
optimize_metric: mean_MASE_forecasting
device: cuda
budget_type: epochs
epochs: 5.555555555555555
torch_num_threads: 1
early_stopping: 20
use_tensorboard_logger: False
metrics_during_training: True
[DEBUG] [2022-08-07 11:05:51,503:Client-TimeSeriesForecastingTrainEvaluator(1)] Search space updates :<autoPyTorch.utils.hyperparameter_search_space_update.HyperparameterSearchSpaceUpdates object at 0x7f76c50b2f70>
[DEBUG] [2022-08-07 11:05:51,503:Client-TimeSeriesForecastingTrainEvaluator(1)] Search space updates :<autoPyTorch.utils.hyperparameter_search_space_update.HyperparameterSearchSpaceUpdates object at 0x7f76c50b2f70>
[INFO] [2022-08-07 11:05:51,504:Client-TimeSeriesForecastingTrainEvaluator(1)] Starting fit 0
[DEBUG] [2022-08-07 11:05:51,877:Client-pynisher] function returned properly: (None, 0)
[DEBUG] [2022-08-07 11:05:51,877:Client-pynisher] return value: (None, 0)
[DEBUG] [2022-08-07 11:05:52,304:Client-TAE] Finish function evaluation 2.
Status: StatusType.CRASHED, Cost: 2147483647.0, Runtime: 3.372134208679199,
Additional information:
traceback: Traceback (most recent call last):
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/evaluation/tae.py", line 61, in fit_predict_try_except_decorator
ta(queue=queue, **kwargs)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py", line 558, in forecasting_eval_train_function
evaluator.fit_predict_and_loss()
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py", line 176, in fit_predict_and_loss
y_train_pred, y_opt_pred, y_valid_pred, y_test_pred = self._fit_and_predict(pipeline, split_id,
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/evaluation/train_evaluator.py", line 364, in _fit_and_predict
fit_and_suppress_warnings(self.logger, pipeline, X, y)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/evaluation/abstract_evaluator.py", line 338, in fit_and_suppress_warnings
pipeline.fit(X, y)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/pipeline/base_pipeline.py", line 157, in fit
X, fit_params = self.fit_transformer(X, y, **fit_params)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/pipeline/base_pipeline.py", line 169, in fit_transformer
Xt = self._fit(X, y, **fit_params_steps)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/sklearn/pipeline.py", line 303, in _fit
X, fitted_transformer = fit_transform_one_cached(
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/joblib/memory.py", line 349, in __call__
return self.func(*args, **kwargs)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/sklearn/pipeline.py", line 756, in _fit_transform_one
res = transformer.fit(X, y, **fit_params).transform(X)
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/pipeline/components/training/data_loader/time_series_forecasting_data_loader.py", line 299, in fit
num_instances_per_seqs = self.compute_expected_num_instances_per_seq(num_instances_dataset,
File "/home/robby/miniconda3/envs/auto-pytorch/lib/python3.8/site-packages/autoPyTorch/pipeline/components/training/data_loader/time_series_forecasting_data_loader.py", line 178, in compute_expected_num_instances_per_seq
seq_idx_inactivate = self.random_state.choice(seq_idx_inactivate, len(seq_idx_inactivate) - 1,
File "mtrand.pyx", line 915, in numpy.random.mtrand.RandomState.choice
ValueError: 'a' cannot be empty unless no samples are taken
error: ValueError("'a' cannot be empty unless no samples are taken")
Further information can be found in the following complete log file: a-cannot-be-empty.log
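The last frame of the traceback can be reproduced with NumPy alone. The sketch below is independent of Auto-PyTorch: the empty array stands in for `seq_idx_inactivate` (the error message implies it was empty at that point), and only the visible part of the failing call is mirrored, since the remaining arguments are cut off in the traceback. The seed is arbitrary.

```python
import numpy as np

rng = np.random.RandomState(0)  # arbitrary seed, for illustration only
candidates = np.array([], dtype=int)  # stands in for seq_idx_inactivate

try:
    # Mirrors the visible part of the failing call:
    # request len(a) - 1 samples from a, which is -1 samples from an empty array.
    rng.choice(candidates, len(candidates) - 1)
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

RandomState.choice rejects any nonzero sample count when the population is empty, so the crash points to an empty index array in the data loader, not to the sampling call itself.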
Your Local environment
Operating System, version
Ubuntu 20.04
Python, version
Python 3.8
Outputs of pip freeze or conda list
Pytorch=1.12+cu116
Hi,
this results from a forecasting_horizon that does not suit the dataset:
in this case, the training/validation set has only one series, with a length of 11. Because we reserve the last 5 (forecasting_horizon) elements of the training/validation set as the validation set, only 6 elements remain in the training set. However, our current implementation requires the targets to have a length of forecasting_horizon, which might result in an empty training split for that specific series.
In practice, we would expect a dataset to have multiple series, and the other series could obtain the information contained in that series. This example is only a toy example showing how to apply AutoPyTorch to time series tasks. But yes, we should raise an error if the forecast horizon does not fit the dataset.
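A check along these lines could look like the following minimal sketch. It is not the Auto-PyTorch implementation: the function names are hypothetical, and it assumes that the last forecasting_horizon values of each series are held out for validation and that each training instance needs window_size inputs followed by forecasting_horizon targets. The value window_size=2 is taken from the logged configuration ('data_loader:window_size': 2).

```python
from typing import List, Sequence


def count_training_windows(series_length: int, forecasting_horizon: int, window_size: int) -> int:
    """Count (input window, target window) pairs left for training.

    Assumption, not taken from the Auto-PyTorch source: the last
    `forecasting_horizon` values are held out for validation, and each
    training instance needs `window_size` inputs followed by
    `forecasting_horizon` targets.
    """
    train_length = series_length - forecasting_horizon
    return max(0, train_length - window_size - forecasting_horizon + 1)


def check_forecasting_horizon(series_lengths: Sequence[int],
                              forecasting_horizon: int,
                              window_size: int = 1) -> List[int]:
    """Raise early if no series can provide a single training instance."""
    if forecasting_horizon < 1 or window_size < 1:
        raise ValueError("forecasting_horizon and window_size must be >= 1")
    counts = [count_training_windows(n, forecasting_horizon, window_size)
              for n in series_lengths]
    if not any(counts):
        raise ValueError(
            f"forecasting_horizon={forecasting_horizon} does not fit the dataset: "
            f"series lengths {list(series_lengths)} leave no training instances"
        )
    return counts


# Horizon 3: 13 training/validation points remain from the 16-row longley data.
print(check_forecasting_horizon([13], forecasting_horizon=3, window_size=2))

# Horizon 5: only 11 points remain, and 6 training points cannot hold 2 + 5 values.
try:
    check_forecasting_horizon([11], forecasting_horizon=5, window_size=2)
except ValueError as err:
    print(err)
```

Checking all series together, not each one separately, follows the remark above that other series in a multi-series dataset can still contribute training instances when one series is too short.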
Issue Description
forecasting_horizon=5 in line 26 of https://github.com/automl/Auto-PyTorch/blob/master/examples/20_basics/example_time_series_forecasting.py.