Skip to content

Commit 8a1bc0f

Browse files
authored
Make observed_value_field optional in TFTInstanceSplitter (#3259)
*Issue #, if available:* N/A *Description of changes:* Previously, `TFTInstanceSplitter` assumed that the `observed_value_field` is always present in the entry. This PR fixes that. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup
1 parent b36d0ce commit 8a1bc0f

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
# scipy cap can be removed once this is resolved: https://github.com/statsmodels/statsmodels/issues/9584
2+
scipy<1.16.0; python_version > "3.7.0"
3+
scipy~=1.7.3; python_version <= "3.7.0"
14
statsforecast~=1.0

requirements/requirements-pytorch.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ torch>=1.9,<3
22
lightning>=2.2.2,<2.5
33
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
44
pytorch_lightning>=2.2.2,<2.5
5-
scipy~=1.10; python_version > "3.7.0"
5+
# scipy cap can be removed once this is resolved: https://github.com/statsmodels/statsmodels/issues/9584
6+
scipy<1.16.0; python_version > "3.7.0"
67
scipy~=1.7.3; python_version <= "3.7.0"

src/gluonts/transform/split.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def __init__(
494494
is_pad_field: str = FieldName.IS_PAD,
495495
start_field: str = FieldName.START,
496496
forecast_start_field: str = FieldName.FORECAST_START,
497-
observed_value_field: str = FieldName.OBSERVED_VALUES,
497+
observed_value_field: Optional[str] = FieldName.OBSERVED_VALUES,
498498
lead_time: int = 0,
499499
output_NTC: bool = True,
500500
time_series_fields: List[str] = [],
@@ -529,11 +529,9 @@ def flatmap_transform(
529529

530530
sampled_indices = self.instance_sampler(target)
531531

532-
slice_cols = (
533-
self.ts_fields
534-
+ self.past_ts_fields
535-
+ [self.target_field, self.observed_value_field]
536-
)
532+
slice_cols = self.ts_fields + self.past_ts_fields + [self.target_field]
533+
if self.observed_value_field is not None:
534+
slice_cols.append(self.observed_value_field)
537535
for i in sampled_indices:
538536
pad_length = max(self.past_length - i, 0)
539537
d = data.copy()

0 commit comments

Comments
 (0)