@@ -26,14 +26,14 @@ class ChronosConfig:
26
26
27
27
tokenizer_class : str
28
28
tokenizer_kwargs : Dict [str , Any ]
29
+ context_length : int
30
+ prediction_length : int
29
31
n_tokens : int
30
32
n_special_tokens : int
31
33
pad_token_id : int
32
34
eos_token_id : int
33
35
use_eos_token : bool
34
36
model_type : Literal ["causal" , "seq2seq" ]
35
- context_length : int
36
- prediction_length : int
37
37
num_samples : int
38
38
temperature : float
39
39
top_k : int
@@ -59,27 +59,55 @@ class ChronosTokenizer:
59
59
which concrete classes must implement.
60
60
"""
61
61
62
    def context_input_transform(
        self,
        context: torch.Tensor,
    ) -> Tuple:
        """
        Turn a batch of time series into token IDs, attention mask, and
        a ``tokenizer_state`` object.

        Abstract method: concrete tokenizers must implement this.

        Parameters
        ----------
        context
            A tensor shaped (batch_size, time_length), containing the
            timeseries to forecast. Use left-padding with ``torch.nan``
            to align time series of different lengths.

        Returns
        -------
        token_ids
            A tensor of integers, shaped (batch_size, time_length + 1)
            if ``config.use_eos_token`` and (batch_size, time_length)
            otherwise, containing token IDs for the input series.
        attention_mask
            A boolean tensor, same shape as ``token_ids``, indicating
            which input observations are not ``torch.nan`` (i.e. not
            missing nor padding).
        tokenizer_state
            An object that can be passed to ``label_input_transform``
            and ``output_transform``. Contains the relevant information
            to decode output samples into real values,
            such as location and scale parameters.
        """
        raise NotImplementedError()
93
+
94
    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
        """
        Turn a batch of label slices of time series into token IDs and attention mask
        using the ``tokenizer_state`` provided by ``context_input_transform``.

        Abstract method: concrete tokenizers must implement this.

        Parameters
        ----------
        label
            A tensor shaped (batch_size, time_length), containing the
            label (future) slices of the time series being forecast.
            Use left-padding with ``torch.nan`` to align time series of
            different lengths.
        tokenizer_state
            An object returned by ``context_input_transform`` containing
            relevant information to preprocess data, such as location and
            scale. The nature of this depends on the specific tokenizer.
            This is used for tokenizing the label, in order to use the same
            scaling used to tokenize the context.

        Returns
        -------
        token_ids
            A tensor of integers, shaped (batch_size, time_length + 1)
            if ``config.use_eos_token`` and (batch_size, time_length)
            otherwise, containing token IDs for the label series.
        attention_mask
            A boolean tensor, same shape as ``token_ids``, indicating
            which input observations are not ``torch.nan`` (i.e. not
            missing nor padding).
        """
        raise NotImplementedError()
100
124
@@ -141,14 +165,9 @@ def __init__(
141
165
)
142
166
)
143
167
144
- def input_transform (
168
+ def _input_transform (
145
169
self , context : torch .Tensor , scale : Optional [torch .Tensor ] = None
146
170
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
147
- batch_size , length = context .shape
148
-
149
- if length > self .config .context_length :
150
- context = context [..., - self .config .context_length :]
151
-
152
171
attention_mask = ~ torch .isnan (context )
153
172
154
173
if scale is None :
@@ -170,16 +189,51 @@ def input_transform(
170
189
)
171
190
token_ids [~ attention_mask ] = self .config .pad_token_id
172
191
173
- if self .config .use_eos_token :
174
- eos_tokens = torch .full (
175
- (batch_size , 1 ), fill_value = self .config .eos_token_id
192
+ return token_ids , attention_mask , scale
193
+
194
+ def _append_eos_token (
195
+ self , token_ids : torch .Tensor , attention_mask : torch .Tensor
196
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
197
+ batch_size = token_ids .shape [0 ]
198
+ eos_tokens = torch .full ((batch_size , 1 ), fill_value = self .config .eos_token_id )
199
+ token_ids = torch .concat ((token_ids , eos_tokens ), dim = 1 )
200
+ eos_mask = torch .full ((batch_size , 1 ), fill_value = True )
201
+ attention_mask = torch .concat ((attention_mask , eos_mask ), dim = 1 )
202
+
203
+ return token_ids , attention_mask
204
+
205
+ def context_input_transform (
206
+ self , context : torch .Tensor
207
+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
208
+ length = context .shape [- 1 ]
209
+
210
+ if length > self .config .context_length :
211
+ context = context [..., - self .config .context_length :]
212
+
213
+ token_ids , attention_mask , scale = self ._input_transform (context = context )
214
+
215
+ if self .config .use_eos_token and self .config .model_type == "seq2seq" :
216
+ token_ids , attention_mask = self ._append_eos_token (
217
+ token_ids = token_ids , attention_mask = attention_mask
176
218
)
177
- token_ids = torch .concat ((token_ids , eos_tokens ), dim = 1 )
178
- eos_mask = torch .full ((batch_size , 1 ), fill_value = True )
179
- attention_mask = torch .concat ((attention_mask , eos_mask ), dim = 1 )
180
219
181
220
return token_ids , attention_mask , scale
182
221
222
+ def label_input_transform (
223
+ self , label : torch .Tensor , scale : torch .Tensor
224
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
225
+ length = label .shape [- 1 ]
226
+
227
+ assert length == self .config .prediction_length
228
+ token_ids , attention_mask , _ = self ._input_transform (context = label , scale = scale )
229
+
230
+ if self .config .use_eos_token :
231
+ token_ids , attention_mask = self ._append_eos_token (
232
+ token_ids = token_ids , attention_mask = attention_mask
233
+ )
234
+
235
+ return token_ids , attention_mask
236
+
183
237
def output_transform (
184
238
self , samples : torch .Tensor , scale : torch .Tensor
185
239
) -> torch .Tensor :
@@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
318
372
return torch .stack (padded )
319
373
320
374
375
+ @dataclass
321
376
class ChronosPipeline :
322
377
"""
323
378
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
@@ -337,10 +392,6 @@ class ChronosPipeline:
337
392
tokenizer : ChronosTokenizer
338
393
model : ChronosModel
339
394
340
- def __init__ (self , tokenizer , model ):
341
- self .tokenizer = tokenizer
342
- self .model = model
343
-
344
395
def _prepare_and_validate_context (
345
396
self , context : Union [torch .Tensor , List [torch .Tensor ]]
346
397
):
@@ -380,8 +431,8 @@ def embed(
380
431
provided, and the extra 1 is for EOS.
381
432
"""
382
433
context_tensor = self ._prepare_and_validate_context (context = context )
383
- token_ids , attention_mask , tokenizer_state = self . tokenizer . input_transform (
384
- context_tensor
434
+ token_ids , attention_mask , tokenizer_state = (
435
+ self . tokenizer . context_input_transform ( context_tensor )
385
436
)
386
437
embeddings = self .model .encode (
387
438
input_ids = token_ids .to (self .model .device ),
@@ -455,7 +506,7 @@ def predict(
455
506
remaining = prediction_length
456
507
457
508
while remaining > 0 :
458
- token_ids , attention_mask , scale = self .tokenizer .input_transform (
509
+ token_ids , attention_mask , scale = self .tokenizer .context_input_transform (
459
510
context_tensor
460
511
)
461
512
samples = self .model (
0 commit comments