
Add support for causal models #113


Merged · 1 commit · Jun 13, 2024
3 changes: 3 additions & 0 deletions scripts/README.md
@@ -89,6 +89,9 @@
The output and checkpoints will be saved in `output/run-{id}/`.
> [!TIP]
> If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`.

> [!IMPORTANT]
> When pretraining causal models (such as GPT2), the training script applies [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) to missing values by default. Make sure missing values are imputed in the same way before passing the context tensor to `ChronosPipeline.predict()`, otherwise the results may be inaccurate.
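
For reference, here is a minimal sketch of imputing a context before prediction. The checkpoint name is a placeholder (use your own fine-tuned causal model) and the series values are made up:

```py
import numpy as np
import torch
from gluonts.transform import LastValueImputation

from chronos import ChronosPipeline

# Placeholder checkpoint: substitute your own fine-tuned causal model
pipeline = ChronosPipeline.from_pretrained(
    "<your_hf_username>/chronos-gpt2-fine-tuned",
    device_map="cpu",
    torch_dtype=torch.float32,
)

# Toy context containing missing values (NaNs)
context = np.array([1.0, 2.0, np.nan, 4.0, np.nan, 6.0], dtype=np.float32)

# Impute missing values the same way the training script does by default
context = LastValueImputation()(context)

forecast = pipeline.predict(torch.tensor(context), prediction_length=12)
```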
- (Optional) Once trained, you can easily push your fine-tuned model to HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to HuggingFace🤗 Hub at `<your_hf_username>/chronos-t5-small-fine-tuned`.
```py
from chronos import ChronosPipeline
# ...
```
35 changes: 35 additions & 0 deletions scripts/training/configs/chronos-gpt2.yaml
@@ -0,0 +1,35 @@
training_data_paths:
- "/home/ubuntu/tsmixup-data.arrow"
- "/home/ubuntu/kernelsynth-data.arrow"
probability:
- 0.9
- 0.1
context_length: 512
prediction_length: 64
min_past: 60
max_steps: 200_000
save_steps: 100_000
log_steps: 500
per_device_train_batch_size: 32
learning_rate: 0.001
optim: adamw_torch_fused
num_samples: 20
shuffle_buffer_length: 100_000
gradient_accumulation_steps: 1
model_id: openai-community/gpt2
model_type: causal
random_init: false
tie_embeddings: false
output_dir: ./output/
tf32: true
torch_compile: true
tokenizer_class: "MeanScaleUniformBins"
tokenizer_kwargs:
  low_limit: -15.0
  high_limit: 15.0
n_tokens: 4096
lr_scheduler_type: linear
warmup_ratio: 0.0
dataloader_num_workers: 1
max_missing_prop: 0.1
use_eos_token: true
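
As a quick sanity check, the snippet below (a sketch; the path assumes the file lives at `scripts/training/configs/chronos-gpt2.yaml` as in this PR) loads the config and inspects a few of the fields specific to this causal GPT2 run:

```py
import yaml

# Load the new config and check the causal-model-specific fields
with open("scripts/training/configs/chronos-gpt2.yaml") as f:
    config = yaml.safe_load(f)

assert config["model_type"] == "causal"
assert config["model_id"] == "openai-community/gpt2"
assert config["tie_embeddings"] is False

print(config["context_length"], config["prediction_length"])  # 512 64
```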
60 changes: 57 additions & 3 deletions scripts/training/train.py
@@ -39,6 +39,9 @@
ValidationSplitSampler,
InstanceSplitter,
ExpectedNumInstanceSampler,
MissingValueImputation,
LeavesMissingValues,
LastValueImputation,
)

from chronos import ChronosConfig, ChronosTokenizer
@@ -301,13 +304,16 @@ def __init__(
prediction_length: int = 64,
drop_prob: float = 0.2,
min_past: Optional[int] = None,
model_type: str = "seq2seq",
imputation_method: Optional[MissingValueImputation] = None,
mode: str = "training",
np_dtype=np.float32,
) -> None:
super().__init__()

assert len(probabilities) == len(datasets)
assert mode in ("training", "validation", "test")
assert model_type in ("seq2seq", "causal")

self.datasets = datasets
self.probabilities = probabilities
@@ -316,6 +322,8 @@ def __init__(
self.prediction_length = prediction_length
self.drop_prob = drop_prob
self.min_past = min_past or prediction_length
self.model_type = model_type
self.imputation_method = imputation_method or LeavesMissingValues()
self.mode = mode
self.np_dtype = np_dtype

@@ -324,6 +332,11 @@ def preprocess_entry(self, entry: dict, mode: str) -> dict:
entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype)
assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1"

if self.model_type == "causal":
# Causal models do not play nice with missing values, so it is
# recommended to use an imputation method, e.g., LastValueImputation
entry["target"] = self.imputation_method(entry["target"])

if mode == "training" and self.drop_prob > 0:
target = entry["target"].copy()
drop_p = np.random.uniform(low=0.0, high=self.drop_prob)
@@ -386,6 +399,48 @@ def to_hf_format(self, entry: dict) -> dict:
future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
labels[labels_mask == 0] = -100

if self.model_type == "causal":
# The InstanceSplitter pads time series on the left to be equal to the
# context_length. However, certain models (e.g., GPT2) with absolute
# position embeddings should not be trained with left padding.
# The following piece of code moves padding from left to right.

assert input_ids.shape[-1] == entry["past_is_pad"].shape[0]

# Find the index where the left padding ends (i.e., the first observed position)
pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1)
padded_input_ids, obs_input_ids = torch.tensor_split(
input_ids, [pad_start_idx], dim=-1
)
padded_attention_mask, obs_attention_mask = torch.tensor_split(
attention_mask, [pad_start_idx], dim=-1
)

# Move padding to the right
input_ids = torch.cat(
[
obs_input_ids,
labels,
padded_input_ids,
],
axis=-1,
)
attention_mask = torch.cat(
[
obs_attention_mask,
labels_mask,
padded_attention_mask,
],
axis=-1,
)

# Labels for causal models are the same as the input_ids;
# internally, transformers shifts the labels by one during training.
labels = input_ids.clone()
input_ids[~attention_mask] = self.tokenizer.config.pad_token_id
labels[~attention_mask] = -100

return {
"input_ids": input_ids.squeeze(0),
"attention_mask": attention_mask.squeeze(0),
@@ -520,9 +575,6 @@ def main(

assert model_type in ["seq2seq", "causal"]

-    if not model_type == "seq2seq":
-        raise NotImplementedError("Only seq2seq models are currently supported")

output_dir = get_next_path("run", base_dir=output_dir, file_type="")

log_on_main(f"Logging dir: {output_dir}", logger)
@@ -588,6 +640,8 @@
context_length=context_length,
prediction_length=prediction_length,
min_past=min_past,
model_type=model_type,
imputation_method=LastValueImputation() if model_type == "causal" else None,
mode="training",
).shuffle(shuffle_buffer_length=shuffle_buffer_length)

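To make the padding relocation in `to_hf_format` easier to follow, here is a toy sketch with made-up token ids and masks (the values, shapes, and variable names are illustrative assumptions, not taken from the training script):

```py
import numpy as np
import torch

# The InstanceSplitter left-pads the past window, so past_is_pad is 1 for
# padded positions and 0 for observed ones.
past_is_pad = np.array([1, 1, 0, 0, 0])          # 2 padded steps, 3 observed steps
input_ids = torch.tensor([[0, 0, 11, 12, 13]])   # fake context token ids, 0 = pad
attention_mask = torch.tensor([[False, False, True, True, True]])
labels = torch.tensor([[21, 22]])                # fake future token ids
labels_mask = torch.tensor([[True, True]])

# First index where 1 - past_is_pad becomes 1, i.e., where observations begin
pad_start_idx = np.searchsorted(1 - past_is_pad, 1)  # -> 2

padded_ids, obs_ids = torch.tensor_split(input_ids, [pad_start_idx], dim=-1)
padded_mask, obs_mask = torch.tensor_split(attention_mask, [pad_start_idx], dim=-1)

# Observed context followed by future tokens; padding moved to the right
input_ids = torch.cat([obs_ids, labels, padded_ids], dim=-1)
attention_mask = torch.cat([obs_mask, labels_mask, padded_mask], dim=-1)

print(input_ids)       # tensor([[11, 12, 13, 21, 22,  0,  0]])
print(attention_mask)  # tensor([[ True,  True,  True,  True,  True, False, False]])
```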
2 changes: 1 addition & 1 deletion src/chronos/chronos.py
@@ -551,7 +551,7 @@ def from_pretrained(cls, *args, **kwargs):
if chronos_config.model_type == "seq2seq":
inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
else:
-        assert config.model_type == "causal"
+        assert chronos_config.model_type == "causal"
inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)

return cls(