Split input_transform into context_input_transform and label_input_transform #82
Merged
Changes from all commits
Commits (11)
e1b1bbf
Add abstractions
84260c5
Add abstractions for context/label input transform
2e9e85d
Push tests
849ec78
Remove duplicate code
dcb01f8
Fix mypy
7c45692
Revert abstractions
5f171d8
Fix type
9d992cb
Update docstring
e49ebce
Fix docstring
f8ca232
Fix
a961f59
Fix docstring
@@ -26,14 +26,14 @@ class ChronosConfig:
     tokenizer_class: str
     tokenizer_kwargs: Dict[str, Any]
-    context_length: int
-    prediction_length: int
     n_tokens: int
     n_special_tokens: int
     pad_token_id: int
     eos_token_id: int
     use_eos_token: bool
     model_type: Literal["causal", "seq2seq"]
+    context_length: int
+    prediction_length: int
     num_samples: int
     temperature: float
     top_k: int
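This hunk only moves ``context_length`` and ``prediction_length`` next to ``model_type``. One general caveat when reordering dataclass fields, illustrated below with a made-up ``ToyConfig`` (not the real ``ChronosConfig``): the generated ``__init__`` changes its positional-argument order, while keyword construction is unaffected.

from dataclasses import dataclass

@dataclass
class ToyConfig:
    # Hypothetical stand-in for ChronosConfig, only to illustrate field ordering.
    model_type: str
    context_length: int
    prediction_length: int

# Keyword construction does not depend on field order, so it keeps working after a reorder.
cfg = ToyConfig(context_length=512, prediction_length=64, model_type="seq2seq")
# Positional construction follows whatever the current field order is.
cfg2 = ToyConfig("seq2seq", 512, 64)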
@@ -59,27 +59,55 @@ class ChronosTokenizer:
     which concrete classes must implement.
     """
 
-    def input_transform(
+    def context_input_transform(
         self,
         context: torch.Tensor,
-        tokenizer_state: Any = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
+    ) -> Tuple:
         """
-        Turn a batch of time series into token IDs, attention map, and scale.
+        Turn a batch of time series into token IDs, attention map, and tokenizer_state.
+
+        Parameters
+        ----------
+        context
+            A tensor shaped (batch_size, time_length), containing the
+            timeseries to forecast. Use left-padding with ``torch.nan``
+            to align time series of different lengths.
+
+        Returns
+        -------
+        token_ids
+            A tensor of integers, shaped (batch_size, time_length + 1)
+            if ``config.use_eos_token`` and (batch_size, time_length)
+            otherwise, containing token IDs for the input series.
+        attention_mask
+            A boolean tensor, same shape as ``token_ids``, indicating
+            which input observations are not ``torch.nan`` (i.e. not
+            missing nor padding).
+        tokenizer_state
+            An object that can be passed to ``label_input_transform``
+            and ``output_transform``. Contains the relevant information
+            to decode output samples into real values,
+            such as location and scale parameters.
+        """
+        raise NotImplementedError()
+
+    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
+        """
+        Turn a batch of label slices of time series into token IDs and attention map
+        using the ``tokenizer_state`` provided by ``context_input_transform``.
 
         Parameters
         ----------
         context
             A tensor shaped (batch_size, time_length), containing the
             timeseries to forecast. Use left-padding with ``torch.nan``
             to align time series of different lengths.
         tokenizer_state
-            An object returned by ``input_transform`` containing
+            An object returned by ``context_input_transform`` containing
             relevant information to preprocess data, such as location and
             scale. The nature of this depends on the specific tokenizer.
-            This is useful when tokenizing the label (for training), in
-            order to use the same scaling used to tokenize the context;
-            when tokenizing the context, this argument should be ignored.
+            This is used for tokenizing the label, in order to use the same
+            scaling used to tokenize the context.
 
         Returns
         -------
@@ -91,10 +119,6 @@ def input_transform(
             A boolean tensor, same shape as ``token_ids``, indicating
             which input observations are not ``torch.nan`` (i.e. not
             missing nor padding).
-        tokenizer_state
-            An object that will be passed to ``output_transform``.
-            Contains the relevant information to decode output samples into
-            real values, such as location and scale parameters.
         """
         raise NotImplementedError()
 
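To make the new abstract contract concrete, here is a minimal, hypothetical tokenizer sketch (it is not part of this PR and not the MeanScaleUniformBins implementation): the per-series scale acts as the ``tokenizer_state`` that flows from ``context_input_transform`` into ``label_input_transform``. A real implementation would subclass ``ChronosTokenizer``.

from typing import Any, Tuple

import torch


class ToyScaleTokenizer:
    """Hypothetical sketch of the split interface; the scale plays the role of tokenizer_state."""

    def _tokenize(self, series: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        attention_mask = ~torch.isnan(series)
        scaled = torch.nan_to_num(series) / scale.unsqueeze(-1)
        token_ids = scaled.round().long()  # stand-in for the real bucketization
        return token_ids, attention_mask

    def context_input_transform(self, context: torch.Tensor) -> Tuple:
        attention_mask = ~torch.isnan(context)
        # Mean absolute value of the observed entries, used as the per-series scale.
        scale = torch.nansum(context.abs(), dim=-1) / attention_mask.sum(dim=-1).clamp(min=1)
        token_ids, attention_mask = self._tokenize(context, scale)
        return token_ids, attention_mask, scale  # scale is returned as tokenizer_state

    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
        # Reuse the scale computed on the context so context and label are tokenized consistently.
        return self._tokenize(label, tokenizer_state)


tokenizer = ToyScaleTokenizer()
ids, mask, state = tokenizer.context_input_transform(torch.randn(2, 16))
label_ids, label_mask = tokenizer.label_input_transform(torch.randn(2, 4), state)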
@@ -141,14 +165,9 @@ def __init__(
             )
         )
 
-    def input_transform(
+    def _input_transform(
         self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        batch_size, length = context.shape
-
-        if length > self.config.context_length:
-            context = context[..., -self.config.context_length :]
-
         attention_mask = ~torch.isnan(context)
 
         if scale is None:
@@ -170,16 +189,51 @@ def input_transform(
         )
         token_ids[~attention_mask] = self.config.pad_token_id
 
-        if self.config.use_eos_token:
-            eos_tokens = torch.full(
-                (batch_size, 1), fill_value=self.config.eos_token_id
+        return token_ids, attention_mask, scale
+
+    def _append_eos_token(
+        self, token_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = token_ids.shape[0]
+        eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
+        token_ids = torch.concat((token_ids, eos_tokens), dim=1)
+        eos_mask = torch.full((batch_size, 1), fill_value=True)
+        attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
+
+        return token_ids, attention_mask
+
+    def context_input_transform(
+        self, context: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        length = context.shape[-1]
+
+        if length > self.config.context_length:
+            context = context[..., -self.config.context_length :]
+
+        token_ids, attention_mask, scale = self._input_transform(context=context)
+
+        if self.config.use_eos_token and self.config.model_type == "seq2seq":
+            token_ids, attention_mask = self._append_eos_token(
+                token_ids=token_ids, attention_mask=attention_mask
             )
-            token_ids = torch.concat((token_ids, eos_tokens), dim=1)
-            eos_mask = torch.full((batch_size, 1), fill_value=True)
-            attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
 
         return token_ids, attention_mask, scale
 
+    def label_input_transform(
+        self, label: torch.Tensor, scale: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        length = label.shape[-1]
+
+        assert length == self.config.prediction_length
+        token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
+
+        if self.config.use_eos_token:
+            token_ids, attention_mask = self._append_eos_token(
+                token_ids=token_ids, attention_mask=attention_mask
+            )
+
+        return token_ids, attention_mask
+
     def output_transform(
         self, samples: torch.Tensor, scale: torch.Tensor
     ) -> torch.Tensor:
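A hedged sketch of how the split is meant to be used at training time with the concrete tokenizer above: tokenize the context window first, then tokenize the label slice with the same scale. The checkpoint name follows the project README; exact shapes depend on the model config.

import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-small")
tokenizer = pipeline.tokenizer

context = torch.randn(2, 96)                                 # (batch_size, history_length)
label = torch.randn(2, tokenizer.config.prediction_length)   # label length must equal prediction_length

# context_input_transform truncates to context_length, tokenizes, and returns the scale.
context_ids, context_mask, scale = tokenizer.context_input_transform(context)

# label_input_transform reuses that scale and appends EOS when use_eos_token is set.
label_ids, label_mask = tokenizer.label_input_transform(label, scale)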
@@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
     return torch.stack(padded)
 
 
+@dataclass
 class ChronosPipeline:
     """
     A ``ChronosPipeline`` uses the given tokenizer and model to forecast
@@ -337,10 +392,6 @@ class ChronosPipeline:
     tokenizer: ChronosTokenizer
     model: ChronosModel
 
-    def __init__(self, tokenizer, model):
-        self.tokenizer = tokenizer
-        self.model = model
-
     def _prepare_and_validate_context(
         self, context: Union[torch.Tensor, List[torch.Tensor]]
     ):
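The hand-written ``__init__`` can be dropped because the ``@dataclass`` decorator added two hunks above generates an equivalent one from the annotated fields. A minimal illustration with a made-up class (not the real ``ChronosPipeline``):

from dataclasses import dataclass


@dataclass
class MiniPipeline:
    # @dataclass generates an __init__ that assigns these fields, so an explicit one is redundant.
    tokenizer: object
    model: object


p = MiniPipeline(tokenizer="toy-tokenizer", model="toy-model")
assert p.tokenizer == "toy-tokenizer" and p.model == "toy-model"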
@@ -380,8 +431,8 @@ def embed(
             provided, and the extra 1 is for EOS.
         """
         context_tensor = self._prepare_and_validate_context(context=context)
-        token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
-            context_tensor
+        token_ids, attention_mask, tokenizer_state = (
+            self.tokenizer.context_input_transform(context_tensor)
         )
         embeddings = self.model.encode(
             input_ids=token_ids.to(self.model.device),
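The public ``embed`` API is unchanged by the rename; only the internal call site switches to ``context_input_transform``. A hedged usage sketch (checkpoint name as in the README):

import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-small")
# embed() tokenizes the context internally and returns the encoder embeddings
# together with the tokenizer_state (the scale for MeanScaleUniformBins).
embeddings, tokenizer_state = pipeline.embed(torch.randn(1, 64))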
@@ -455,7 +506,7 @@ def predict(
         remaining = prediction_length
 
         while remaining > 0:
-            token_ids, attention_mask, scale = self.tokenizer.input_transform(
+            token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
                 context_tensor
             )
             samples = self.model(