
Split input_transform into context_input_transform and label_input_transform #82


Merged
merged 11 commits into from May 28, 2024
6 changes: 4 additions & 2 deletions scripts/training/train.py
@@ -387,9 +387,11 @@ def create_validation_data(self, data):

def to_hf_format(self, entry: dict) -> dict:
past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
input_ids, attention_mask, scale = self.tokenizer.input_transform(past_target)
input_ids, attention_mask, scale = self.tokenizer.context_input_transform(
past_target
)
future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
labels, labels_mask, _ = self.tokenizer.input_transform(future_target, scale)
labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
labels[labels_mask == 0] = -100
return {
"input_ids": input_ids.squeeze(0),
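For reference, this is the updated to_hf_format step consolidated from the hunk above, as a sketch only: the explicit tokenizer argument and the dictionary keys after "input_ids" are assumptions for illustration, not taken verbatim from the file.

import torch

def to_hf_format(entry: dict, tokenizer) -> dict:
    # Tokenize the context; the returned scale is reused below for the labels.
    past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
    input_ids, attention_mask, scale = tokenizer.context_input_transform(past_target)

    # Tokenize the labels with the scale computed from the context, so both
    # parts of the series are quantized consistently.
    future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
    labels, labels_mask = tokenizer.label_input_transform(future_target, scale)
    labels[labels_mask == 0] = -100  # positions excluded from the loss

    return {
        "input_ids": input_ids.squeeze(0),
        "attention_mask": attention_mask.squeeze(0),
        "labels": labels.squeeze(0),
    }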
117 changes: 84 additions & 33 deletions src/chronos/chronos.py
@@ -26,14 +26,14 @@ class ChronosConfig:

tokenizer_class: str
tokenizer_kwargs: Dict[str, Any]
context_length: int
prediction_length: int
n_tokens: int
n_special_tokens: int
pad_token_id: int
eos_token_id: int
use_eos_token: bool
model_type: Literal["causal", "seq2seq"]
context_length: int
prediction_length: int
num_samples: int
temperature: float
top_k: int
@@ -59,27 +59,55 @@ class ChronosTokenizer:
which concrete classes must implement.
"""

def input_transform(
def context_input_transform(
self,
context: torch.Tensor,
tokenizer_state: Any = None,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
) -> Tuple:
"""
Turn a batch of time series into token IDs, attention map, and scale.
Turn a batch of time series into token IDs, attention map, and tokenizer_state.

Parameters
----------
context
A tensor shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``torch.nan``
to align time series of different lengths.

Returns
-------
token_ids
A tensor of integers, shaped (batch_size, time_length + 1)
if ``config.use_eos_token`` and (batch_size, time_length)
otherwise, containing token IDs for the input series.
attention_mask
A boolean tensor, same shape as ``token_ids``, indicating
which input observations are not ``torch.nan`` (i.e. not
missing nor padding).
tokenizer_state
An object returned by ``input_transform`` containing
An object that can be passed to ``label_input_transform``
and ``output_transform``. Contains the relevant information
to decode output samples into real values,
such as location and scale parameters.
"""
raise NotImplementedError()

def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
"""
Turn a batch of label slices of time series into token IDs and attention map
using the ``tokenizer_state`` provided by ``context_input_transform``.

Parameters
----------
context
A tensor shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``torch.nan``
to align time series of different lengths.
tokenizer_state
An object returned by ``context_input_transform`` containing
relevant information to preprocess data, such as location and
scale. The nature of this depends on the specific tokenizer.
This is useful when tokenizing the label (for training), in
order to use the same scaling used to tokenize the context;
when tokenizing the context, this argument should be ignored.
This is used for tokenizing the label, in order to use the same
scaling used to tokenize the context.

Returns
-------
@@ -91,10 +119,6 @@ def input_transform(
A boolean tensor, same shape as ``token_ids``, indicating
which input observations are not ``torch.nan`` (i.e. not
missing nor padding).
tokenizer_state
An object that will be passed to ``output_transform``.
Contains the relevant information to decode output samples into
real values, such as location and scale parameters.
"""
raise NotImplementedError()

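To make the contract above concrete, the following is a minimal, self-contained sketch of a toy tokenizer, plain mean scaling into a handful of uniform bins rather than the library's MeanScaleUniformBins, showing how the state returned by context_input_transform (here just the scale) is handed to label_input_transform.

from typing import Tuple

import torch

class ToyMeanScaleTokenizer:
    """Illustrative stand-in for a concrete ChronosTokenizer; not the real implementation."""

    def __init__(self, n_bins: int = 100, low: float = -5.0, high: float = 5.0):
        # Bin boundaries used to quantize scaled values into token IDs.
        self.boundaries = torch.linspace(low, high, n_bins - 1)

    def _transform(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        attention_mask = ~torch.isnan(x)
        token_ids = torch.bucketize(torch.nan_to_num(x) / scale.unsqueeze(-1), self.boundaries)
        return token_ids, attention_mask

    def context_input_transform(self, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # The scale is computed from the context only and returned as the tokenizer state.
        scale = torch.nansum(context.abs(), dim=-1) / torch.sum(~torch.isnan(context), dim=-1)
        token_ids, attention_mask = self._transform(context, scale)
        return token_ids, attention_mask, scale

    def label_input_transform(self, label: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Reuse the context-derived scale; no new statistics are computed from the label.
        token_ids, attention_mask = self._transform(label, scale)
        return token_ids, attention_mask

tokenizer = ToyMeanScaleTokenizer()
context = torch.arange(1.0, 13.0).unsqueeze(0)   # shape (1, 12)
label = torch.arange(13.0, 17.0).unsqueeze(0)    # shape (1, 4)
ids, mask, state = tokenizer.context_input_transform(context)
label_ids, label_mask = tokenizer.label_input_transform(label, state)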
@@ -141,14 +165,9 @@ def __init__(
)
)

def input_transform(
def _input_transform(
self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, length = context.shape

if length > self.config.context_length:
context = context[..., -self.config.context_length :]

attention_mask = ~torch.isnan(context)

if scale is None:
@@ -170,16 +189,51 @@ def input_transform(
)
token_ids[~attention_mask] = self.config.pad_token_id

if self.config.use_eos_token:
eos_tokens = torch.full(
(batch_size, 1), fill_value=self.config.eos_token_id
return token_ids, attention_mask, scale

def _append_eos_token(
self, token_ids: torch.Tensor, attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = token_ids.shape[0]
eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
eos_mask = torch.full((batch_size, 1), fill_value=True)
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)

return token_ids, attention_mask

def context_input_transform(
self, context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
length = context.shape[-1]

if length > self.config.context_length:
context = context[..., -self.config.context_length :]

token_ids, attention_mask, scale = self._input_transform(context=context)
Contributor comment:


I'm wondering: is _input_transform needed, or could context_input_transform just piggy-back on label_input_transform here?


if self.config.use_eos_token and self.config.model_type == "seq2seq":
token_ids, attention_mask = self._append_eos_token(
token_ids=token_ids, attention_mask=attention_mask
)
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
eos_mask = torch.full((batch_size, 1), fill_value=True)
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)

return token_ids, attention_mask, scale

def label_input_transform(
self, label: torch.Tensor, scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
length = label.shape[-1]

assert length == self.config.prediction_length
token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)

if self.config.use_eos_token:
token_ids, attention_mask = self._append_eos_token(
token_ids=token_ids, attention_mask=attention_mask
)

return token_ids, attention_mask

def output_transform(
self, samples: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
@@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
return torch.stack(padded)


@dataclass
class ChronosPipeline:
"""
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
@@ -337,10 +392,6 @@ class ChronosPipeline:
tokenizer: ChronosTokenizer
model: ChronosModel

def __init__(self, tokenizer, model):
self.tokenizer = tokenizer
self.model = model

def _prepare_and_validate_context(
self, context: Union[torch.Tensor, List[torch.Tensor]]
):
@@ -380,8 +431,8 @@ def embed(
provided, and the extra 1 is for EOS.
"""
context_tensor = self._prepare_and_validate_context(context=context)
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
context_tensor
token_ids, attention_mask, tokenizer_state = (
self.tokenizer.context_input_transform(context_tensor)
)
embeddings = self.model.encode(
input_ids=token_ids.to(self.model.device),
@@ -455,7 +506,7 @@ def predict(
remaining = prediction_length

while remaining > 0:
token_ids, attention_mask, scale = self.tokenizer.input_transform(
token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
context_tensor
)
samples = self.model(
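From the caller's side nothing changes: embed and predict still take a raw context tensor and now route it through context_input_transform internally. A usage sketch, with the from_pretrained constructor and the amazon/chronos-t5-small checkpoint taken from the Chronos README rather than from this diff:

import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-small")

# A toy random-walk series standing in for real data.
context = torch.randn(100).cumsum(0)

# predict() calls tokenizer.context_input_transform on every autoregressive
# step; the scale it returns is what output_transform later uses to map
# sampled tokens back to the original range.
forecast = pipeline.predict(context, prediction_length=12)
print(forecast.shape)  # expected: (1, num_samples, 12)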
8 changes: 4 additions & 4 deletions test/test_chronos.py
@@ -38,10 +38,10 @@ def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
context = tokenizer.centers.unsqueeze(0) # add batch dimension
scale = torch.ones((1,)) # fix the scale to one to turn off scaling

token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
token_ids, _, _ = tokenizer._input_transform(context, scale=scale)

samples = tokenizer.output_transform(
token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension
token_ids.unsqueeze(1), # add sample dimension
scale=scale,
)

@@ -85,7 +85,7 @@ def test_tokenizer_fixed_data(
)
batch_size, _ = context.shape

token_ids, attention_mask, scale = tokenizer.input_transform(context)
token_ids, attention_mask, scale = tokenizer.context_input_transform(context)

assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
@@ -136,7 +136,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
]
)

token_ids, attention_mask, scale = tokenizer.input_transform(context)
token_ids, attention_mask, scale = tokenizer.context_input_transform(context)

assert token_ids.shape == (
*context.shape[:-1],
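One property of the new split that the updated tests above do not exercise directly is that label_input_transform really does reuse the scale computed from the context. A test sketch along those lines; the tokenizer argument is a hypothetical fixture that would have to be built the same way as the tokenizers in the existing tests, and the expected length assumes the EOS behavior shown in the diff.

import torch

def test_label_transform_reuses_context_scale(tokenizer):
    # ``tokenizer`` is a hypothetical fixture: a MeanScaleUniformBins instance
    # constructed like the ones in the tests above.
    context = 10.0 * torch.ones((2, tokenizer.config.context_length))
    label = 10.0 * torch.ones((2, tokenizer.config.prediction_length))

    _, _, scale = tokenizer.context_input_transform(context)
    label_ids, label_mask = tokenizer.label_input_transform(label, scale)

    # label_input_transform asserts that the label is exactly prediction_length
    # long and appends one EOS token when use_eos_token is set.
    expected_length = tokenizer.config.prediction_length + 1 * tokenizer.config.use_eos_token
    assert label_ids.shape == (2, expected_length)
    assert label_mask.all()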