
Commit 223e576

Authored by abdulfatir (Abdul Fatir Ansari)
Split input_transform into context_input_transform and label_input_transform (#82)
*Description of changes:* This splits `input_transform` into `context_input_transform` and `label_input_transform`. Previously, `input_transform` was used for both the context and the label during training, which led to incorrect results when `prediction_length` > `context_length`.

TODO:

- [x] Update docstrings
- [x] Test the training script

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Co-authored-by: Abdul Fatir Ansari <[email protected]>
1 parent ea26e3d commit 223e576

File tree: 3 files changed, +92 -39 lines changed
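Before the per-file diffs, here is a minimal sketch of the new call pattern during training. It is not part of the commit: the `ChronosConfig` fields follow the dataclass in the diff below, but `top_p`, the `low_limit`/`high_limit` tokenizer kwargs, and the `create_tokenizer()` helper are assumptions about surrounding chronos code that this diff does not show, and all numeric values are illustrative. The point of the split is that the label window is tokenized with the scale estimated from the context and is no longer routed through the context-length truncation.

```python
import torch

from chronos import ChronosConfig  # assumption: exported at package level

# Illustrative config; top_p, low_limit/high_limit and create_tokenizer() are
# assumptions not shown in this diff -- check them against the source.
config = ChronosConfig(
    tokenizer_class="MeanScaleUniformBins",
    tokenizer_kwargs=dict(low_limit=-15.0, high_limit=15.0),
    context_length=512,
    prediction_length=64,
    n_tokens=4096,
    n_special_tokens=2,
    pad_token_id=0,
    eos_token_id=1,
    use_eos_token=True,
    model_type="seq2seq",
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
)
tokenizer = config.create_tokenizer()

past_target = torch.randn(1, 512).cumsum(-1)   # context window
future_target = torch.randn(1, 64).cumsum(-1)  # label window

# Context path: estimates the scale and, for seq2seq models, appends EOS.
input_ids, attention_mask, scale = tokenizer.context_input_transform(past_target)

# Label path: returns only (token_ids, mask) and reuses the context's scale.
labels, labels_mask = tokenizer.label_input_transform(future_target, scale)
```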

scripts/training/train.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -387,9 +387,11 @@ def create_validation_data(self, data):
 
     def to_hf_format(self, entry: dict) -> dict:
         past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
-        input_ids, attention_mask, scale = self.tokenizer.input_transform(past_target)
+        input_ids, attention_mask, scale = self.tokenizer.context_input_transform(
+            past_target
+        )
         future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
-        labels, labels_mask, _ = self.tokenizer.input_transform(future_target, scale)
+        labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
         labels[labels_mask == 0] = -100
         return {
             "input_ids": input_ids.squeeze(0),
```

src/chronos/chronos.py

Lines changed: 84 additions & 33 deletions
```diff
@@ -26,14 +26,14 @@ class ChronosConfig:
 
     tokenizer_class: str
     tokenizer_kwargs: Dict[str, Any]
+    context_length: int
+    prediction_length: int
     n_tokens: int
     n_special_tokens: int
     pad_token_id: int
     eos_token_id: int
     use_eos_token: bool
     model_type: Literal["causal", "seq2seq"]
-    context_length: int
-    prediction_length: int
     num_samples: int
     temperature: float
     top_k: int
@@ -59,27 +59,55 @@ class ChronosTokenizer:
     which concrete classes must implement.
     """
 
-    def input_transform(
+    def context_input_transform(
         self,
         context: torch.Tensor,
-        tokenizer_state: Any = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
+    ) -> Tuple:
         """
-        Turn a batch of time series into token IDs, attention map, and scale.
+        Turn a batch of time series into token IDs, attention map, and tokenizer_state.
 
         Parameters
         ----------
         context
             A tensor shaped (batch_size, time_length), containing the
             timeseries to forecast. Use left-padding with ``torch.nan``
             to align time series of different lengths.
+
+        Returns
+        -------
+        token_ids
+            A tensor of integers, shaped (batch_size, time_length + 1)
+            if ``config.use_eos_token`` and (batch_size, time_length)
+            otherwise, containing token IDs for the input series.
+        attention_mask
+            A boolean tensor, same shape as ``token_ids``, indicating
+            which input observations are not ``torch.nan`` (i.e. not
+            missing nor padding).
         tokenizer_state
-            An object returned by ``input_transform`` containing
+            An object that can be passed to ``label_input_transform``
+            and ``output_transform``. Contains the relevant information
+            to decode output samples into real values,
+            such as location and scale parameters.
+        """
+        raise NotImplementedError()
+
+    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
+        """
+        Turn a batch of label slices of time series into token IDs and attention map
+        using the ``tokenizer_state`` provided by ``context_input_transform``.
+
+        Parameters
+        ----------
+        label
+            A tensor shaped (batch_size, time_length), containing the
+            timeseries to forecast. Use left-padding with ``torch.nan``
+            to align time series of different lengths.
+        tokenizer_state
+            An object returned by ``context_input_transform`` containing
             relevant information to preprocess data, such as location and
             scale. The nature of this depends on the specific tokenizer.
-            This is useful when tokenizing the label (for training), in
-            order to use the same scaling used to tokenize the context;
-            when tokenizing the context, this argument should be ignored.
+            This is used for tokenizing the label, in order to use the same
+            scaling used to tokenize the context.
 
         Returns
         -------
@@ -91,10 +119,6 @@ def input_transform(
             A boolean tensor, same shape as ``token_ids``, indicating
             which input observations are not ``torch.nan`` (i.e. not
             missing nor padding).
-        tokenizer_state
-            An object that will be passed to ``output_transform``.
-            Contains the relevant information to decode output samples into
-            real values, such as location and scale parameters.
         """
         raise NotImplementedError()
 
@@ -141,14 +165,9 @@ def __init__(
             )
         )
 
-    def input_transform(
+    def _input_transform(
         self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        batch_size, length = context.shape
-
-        if length > self.config.context_length:
-            context = context[..., -self.config.context_length :]
-
         attention_mask = ~torch.isnan(context)
 
         if scale is None:
@@ -170,16 +189,51 @@ def input_transform(
         )
         token_ids[~attention_mask] = self.config.pad_token_id
 
-        if self.config.use_eos_token:
-            eos_tokens = torch.full(
-                (batch_size, 1), fill_value=self.config.eos_token_id
+        return token_ids, attention_mask, scale
+
+    def _append_eos_token(
+        self, token_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = token_ids.shape[0]
+        eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
+        token_ids = torch.concat((token_ids, eos_tokens), dim=1)
+        eos_mask = torch.full((batch_size, 1), fill_value=True)
+        attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
+
+        return token_ids, attention_mask
+
+    def context_input_transform(
+        self, context: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        length = context.shape[-1]
+
+        if length > self.config.context_length:
+            context = context[..., -self.config.context_length :]
+
+        token_ids, attention_mask, scale = self._input_transform(context=context)
+
+        if self.config.use_eos_token and self.config.model_type == "seq2seq":
+            token_ids, attention_mask = self._append_eos_token(
+                token_ids=token_ids, attention_mask=attention_mask
             )
-            token_ids = torch.concat((token_ids, eos_tokens), dim=1)
-            eos_mask = torch.full((batch_size, 1), fill_value=True)
-            attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
 
         return token_ids, attention_mask, scale
 
+    def label_input_transform(
+        self, label: torch.Tensor, scale: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        length = label.shape[-1]
+
+        assert length == self.config.prediction_length
+        token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
+
+        if self.config.use_eos_token:
+            token_ids, attention_mask = self._append_eos_token(
+                token_ids=token_ids, attention_mask=attention_mask
+            )
+
+        return token_ids, attention_mask
+
     def output_transform(
         self, samples: torch.Tensor, scale: torch.Tensor
     ) -> torch.Tensor:
@@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
     return torch.stack(padded)
 
 
+@dataclass
 class ChronosPipeline:
     """
     A ``ChronosPipeline`` uses the given tokenizer and model to forecast
@@ -337,10 +392,6 @@ class ChronosPipeline:
     tokenizer: ChronosTokenizer
     model: ChronosModel
 
-    def __init__(self, tokenizer, model):
-        self.tokenizer = tokenizer
-        self.model = model
-
     def _prepare_and_validate_context(
         self, context: Union[torch.Tensor, List[torch.Tensor]]
     ):
@@ -380,8 +431,8 @@ def embed(
             provided, and the extra 1 is for EOS.
         """
         context_tensor = self._prepare_and_validate_context(context=context)
-        token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
-            context_tensor
+        token_ids, attention_mask, tokenizer_state = (
+            self.tokenizer.context_input_transform(context_tensor)
         )
         embeddings = self.model.encode(
             input_ids=token_ids.to(self.model.device),
@@ -455,7 +506,7 @@ def predict(
         remaining = prediction_length
 
         while remaining > 0:
-            token_ids, attention_mask, scale = self.tokenizer.input_transform(
+            token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
                 context_tensor
             )
             samples = self.model(
```
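To tie the tokenizer changes together, here is a hedged round-trip sketch using the `MeanScaleUniformBins` methods from the diff above: tokenize a context, then decode the token IDs back to real values with the same scale. It mirrors what the first test below does with `_input_transform`, and assumes `tokenizer` is the instance built in the earlier sketch (EOS token enabled, `model_type="seq2seq"`).

```python
import torch

# `tokenizer` is assumed to be the MeanScaleUniformBins instance built in the
# earlier sketch (use_eos_token=True, model_type="seq2seq").
context = torch.arange(1.0, 21.0).unsqueeze(0)  # (1, 20) toy series

token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
# token_ids has shape (1, 21): 20 observations plus one trailing EOS column.

# output_transform expects (batch, num_samples, length); drop the EOS column first.
decoded = tokenizer.output_transform(token_ids[:, :-1].unsqueeze(1), scale=scale)
# `decoded` approximates `context` up to the quantization error of the bins.
```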

test/test_chronos.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -38,10 +38,10 @@ def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
     context = tokenizer.centers.unsqueeze(0)  # add batch dimension
     scale = torch.ones((1,))  # fix the scale to one to turn off scaling
 
-    token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
+    token_ids, _, _ = tokenizer._input_transform(context, scale=scale)
 
     samples = tokenizer.output_transform(
-        token_ids[:, :-1].unsqueeze(1),  # remove final EOS, add sample dimension
+        token_ids.unsqueeze(1),  # add sample dimension
         scale=scale,
     )
 
@@ -85,7 +85,7 @@ def test_tokenizer_fixed_data(
     )
     batch_size, _ = context.shape
 
-    token_ids, attention_mask, scale = tokenizer.input_transform(context)
+    token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
 
     assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
     assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
@@ -136,7 +136,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
         ]
     )
 
-    token_ids, attention_mask, scale = tokenizer.input_transform(context)
+    token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
 
     assert token_ids.shape == (
         *context.shape[:-1],
```
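The updated tests only exercise `context_input_transform` and the renamed `_input_transform`; a natural companion check for the new label path is sketched below. The function name is hypothetical, and it takes the tokenizer as an argument so it can be wired to whatever construction the existing tests use; the hard-coded `4` stands in for `config.prediction_length`.

```python
import torch

def check_label_input_transform(tokenizer) -> None:
    # `tokenizer` is expected to be a MeanScaleUniformBins instance whose config
    # has prediction_length == 4; build it the same way the tests above do.
    context = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
    label = torch.tensor([[50.0, 60.0, 70.0, 80.0]])

    _, _, scale = tokenizer.context_input_transform(context)
    labels, labels_mask = tokenizer.label_input_transform(label, scale)

    # Two return values (no scale); one extra column when an EOS token is appended.
    assert labels.shape[-1] == 4 + 1 * tokenizer.config.use_eos_token
    assert labels_mask.all()  # no NaNs or padding in this toy label
```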
