diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 175b5dbb..cc1d3fb6 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import chronos import torch import torch.nn as nn from transformers import ( @@ -45,9 +46,8 @@ def __post_init__(self): ), f"Special token id's must be smaller than {self.n_special_tokens=}" def create_tokenizer(self) -> "ChronosTokenizer": - if self.tokenizer_class == "MeanScaleUniformBins": - return MeanScaleUniformBins(**self.tokenizer_kwargs, config=self) - raise ValueError + class_ = getattr(chronos, self.tokenizer_class) + return class_(**self.tokenizer_kwargs, config=self) class ChronosTokenizer: