Skip to content

Commit 3d2ae40

Browse files
committed
address comment
1 parent 975172a commit 3d2ae40

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

src/chronos/chronos.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def __init__(
169169
def _input_transform(
170170
self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
171171
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
172+
context = context.to(dtype=torch.float32)
172173
attention_mask = ~torch.isnan(context)
173174

174175
if scale is None:
@@ -370,7 +371,10 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
370371
assert isinstance(c, torch.Tensor)
371372
assert c.ndim == 1
372373
padding = torch.full(
373-
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
374+
size=(max_len - len(c),),
375+
fill_value=torch.nan,
376+
device=c.device,
377+
dtype=c.dtype,
374378
)
375379
padded.append(torch.concat((padding, c), dim=-1))
376380
return torch.stack(padded)
@@ -397,7 +401,7 @@ class ChronosPipeline:
397401
model: ChronosModel
398402

399403
def _prepare_and_validate_context(
400-
self, context: Union[torch.Tensor, List[torch.Tensor]], dtype=torch.float32
404+
self, context: Union[torch.Tensor, List[torch.Tensor]]
401405
):
402406
if isinstance(context, list):
403407
context = left_pad_and_stack_1D(context)
@@ -406,7 +410,7 @@ def _prepare_and_validate_context(
406410
context = context.unsqueeze(0)
407411
assert context.ndim == 2
408412

409-
return context.to(dtype=dtype)
413+
return context
410414

411415
@torch.no_grad()
412416
def embed(
@@ -506,6 +510,9 @@ def predict(
506510
raise ValueError(msg)
507511
warnings.warn(msg)
508512

513+
input_dtype = context_tensor.dtype
514+
input_device = context_tensor.device
515+
509516
predictions = []
510517
remaining = prediction_length
511518

@@ -536,7 +543,7 @@ def predict(
536543
[context_tensor, prediction.median(dim=1).values], dim=-1
537544
)
538545

539-
return torch.cat(predictions, dim=-1)
546+
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
540547

541548
@classmethod
542549
def from_pretrained(cls, *args, **kwargs):

test/test_chronos.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,32 +163,33 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
163163
assert a.dtype == dtype
164164

165165

166-
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
167-
def test_pipeline_predict(torch_dtype: str):
166+
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
167+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
168+
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
168169
pipeline = ChronosPipeline.from_pretrained(
169170
Path(__file__).parent / "dummy-chronos-model",
170171
device_map="cpu",
171-
torch_dtype=torch_dtype,
172+
torch_dtype=model_dtype,
172173
)
173-
context = 10 * torch.rand(size=(4, 16)) + 10
174+
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
174175

175176
# input: tensor of shape (batch_size, context_length)
176177

177178
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
178-
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
179+
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
179180

180181
with pytest.raises(ValueError):
181182
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
182183

183184
samples = pipeline.predict(
184185
context, num_samples=7, prediction_length=65, limit_prediction_length=False
185186
)
186-
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
187+
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
187188

188189
# input: batch_size-long list of tensors of shape (context_length,)
189190

190191
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
191-
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
192+
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
192193

193194
with pytest.raises(ValueError):
194195
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
@@ -199,12 +200,12 @@ def test_pipeline_predict(torch_dtype: str):
199200
prediction_length=65,
200201
limit_prediction_length=False,
201202
)
202-
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
203+
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
203204

204205
# input: tensor of shape (context_length,)
205206

206207
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
207-
validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32)
208+
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
208209

209210
with pytest.raises(ValueError):
210211
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
@@ -215,40 +216,41 @@ def test_pipeline_predict(torch_dtype: str):
215216
prediction_length=65,
216217
limit_prediction_length=False,
217218
)
218-
validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32)
219+
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
219220

220221

221-
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
222-
def test_pipeline_embed(torch_dtype: str):
222+
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
223+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
224+
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
223225
pipeline = ChronosPipeline.from_pretrained(
224226
Path(__file__).parent / "dummy-chronos-model",
225227
device_map="cpu",
226-
torch_dtype=torch_dtype,
228+
torch_dtype=model_dtype,
227229
)
228230
d_model = pipeline.model.model.config.d_model
229-
context = 10 * torch.rand(size=(4, 16)) + 10
231+
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
230232
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
231233

232234
# input: tensor of shape (batch_size, context_length)
233235

234236
embedding, scale = pipeline.embed(context)
235237
validate_tensor(
236-
embedding, shape=(4, expected_embed_length, d_model), dtype=torch_dtype
238+
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
237239
)
238240
validate_tensor(scale, shape=(4,), dtype=torch.float32)
239241

240242
# input: batch_size-long list of tensors of shape (context_length,)
241243

242244
embedding, scale = pipeline.embed(list(context))
243245
validate_tensor(
244-
embedding, shape=(4, expected_embed_length, d_model), dtype=torch_dtype
246+
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
245247
)
246248
validate_tensor(scale, shape=(4,), dtype=torch.float32)
247249

248250
# input: tensor of shape (context_length,)
249251
embedding, scale = pipeline.embed(context[0, ...])
250252
validate_tensor(
251-
embedding, shape=(1, expected_embed_length, d_model), dtype=torch_dtype
253+
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
252254
)
253255
validate_tensor(scale, shape=(1,), dtype=torch.float32)
254256

0 commit comments

Comments
 (0)