Skip to content

Commit 4c43cfb

Browse files
abdulfatirAbdul Fatir Ansari
andauthored
Return predictions in fp32 on CPU (#219)
*Issue #, if available:* N/A *Description of changes:* This PR ensures that predictions are returned in FP32 and on the CPU device. This choice is now better because we have two types of models which have different types of forecasts (samples vs. quantiles). Furthermore, `int64` input_type (our README example is one such case) ran into issues with `predict_quantiles` before. This choice also fixes that. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari <[email protected]>
1 parent c887278 commit 4c43cfb

File tree

7 files changed

+85
-56
lines changed

7 files changed

+85
-56
lines changed

src/chronos/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def predict(
6767
**kwargs,
6868
):
6969
"""
70-
Get forecasts for the given time series.
70+
Get forecasts for the given time series. Predictions will be
71+
returned in fp32 on the cpu.
7172
7273
Parameters
7374
----------
@@ -97,6 +98,7 @@ def predict_quantiles(
9798
) -> Tuple[torch.Tensor, torch.Tensor]:
9899
"""
99100
Get quantile and mean forecasts for given time series.
101+
Predictions will be returned in fp32 on the cpu.
100102
101103
Parameters
102104
----------

src/chronos/chronos.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,9 +500,6 @@ def predict( # type: ignore[override]
500500
raise ValueError(msg)
501501
logger.warning(msg)
502502

503-
input_dtype = context_tensor.dtype
504-
input_device = context_tensor.device
505-
506503
predictions = []
507504
remaining = prediction_length
508505

@@ -533,7 +530,7 @@ def predict( # type: ignore[override]
533530
[context_tensor, prediction.median(dim=1).values], dim=-1
534531
)
535532

536-
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
533+
return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu")
537534

538535
def predict_quantiles(
539536
self,

src/chronos/chronos_bolt.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,14 @@ def predict( # type: ignore[override]
487487
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
488488
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
489489
# every 64 steps.
490+
context_tensor = context_tensor.to(
491+
device=self.model.device,
492+
dtype=torch.float32,
493+
)
490494
while remaining > 0:
491495
with torch.no_grad():
492496
prediction = self.model(
493-
context=context_tensor.to(
494-
device=self.model.device,
495-
dtype=torch.float32, # scaling should be done in 32-bit precision
496-
),
497+
context=context_tensor,
497498
).quantile_preds.to(context_tensor)
498499

499500
predictions.append(prediction)
@@ -507,7 +508,9 @@ def predict( # type: ignore[override]
507508

508509
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
509510

510-
return torch.cat(predictions, dim=-1)[..., :prediction_length]
511+
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
512+
dtype=torch.float32, device="cpu"
513+
)
511514

512515
def predict_quantiles(
513516
self,

test/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0

test/test_chronos.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from pathlib import Path
5-
from typing import Optional, Tuple
65

76
import pytest
87
import torch
@@ -13,6 +12,7 @@
1312
ChronosPipeline,
1413
MeanScaleUniformBins,
1514
)
15+
from test.util import validate_tensor
1616

1717

1818
def test_base_chronos_pipeline_loads_from_huggingface():
@@ -166,30 +166,21 @@ def test_tokenizer_random_data(use_eos_token: bool):
166166
assert samples.shape == (2, 10, 4)
167167

168168

169-
def validate_tensor(
170-
a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
171-
) -> None:
172-
assert isinstance(a, torch.Tensor)
173-
assert a.shape == shape
174-
175-
if dtype is not None:
176-
assert a.dtype == dtype
177-
178-
179169
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
180-
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
170+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
181171
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
182172
pipeline = ChronosPipeline.from_pretrained(
183173
Path(__file__).parent / "dummy-chronos-model",
184174
device_map="cpu",
185175
torch_dtype=model_dtype,
186176
)
187-
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
177+
context = 10 * torch.rand(size=(4, 16)) + 10
178+
context = context.to(dtype=input_dtype)
188179

189180
# input: tensor of shape (batch_size, context_length)
190181

191182
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
192-
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
183+
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
193184

194185
with pytest.raises(ValueError):
195186
samples = pipeline.predict(
@@ -199,12 +190,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
199190
samples = pipeline.predict(
200191
context, num_samples=7, prediction_length=65, limit_prediction_length=False
201192
)
202-
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
193+
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
203194

204195
# input: batch_size-long list of tensors of shape (context_length,)
205196

206197
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
207-
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
198+
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
208199

209200
with pytest.raises(ValueError):
210201
samples = pipeline.predict(
@@ -220,12 +211,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
220211
prediction_length=65,
221212
limit_prediction_length=False,
222213
)
223-
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
214+
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
224215

225216
# input: tensor of shape (context_length,)
226217

227218
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
228-
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
219+
validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32)
229220

230221
with pytest.raises(ValueError):
231222
samples = pipeline.predict(
@@ -240,16 +231,18 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
240231
num_samples=7,
241232
prediction_length=65,
242233
)
243-
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
234+
validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32)
244235

245236

246237
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
238+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
247239
@pytest.mark.parametrize("prediction_length", [3, 65])
248240
@pytest.mark.parametrize(
249241
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
250242
)
251243
def test_pipeline_predict_quantiles(
252244
model_dtype: torch.dtype,
245+
input_dtype: torch.dtype,
253246
prediction_length: int,
254247
quantile_levels: list[int],
255248
):
@@ -259,6 +252,7 @@ def test_pipeline_predict_quantiles(
259252
torch_dtype=model_dtype,
260253
)
261254
context = 10 * torch.rand(size=(4, 16)) + 10
255+
context = context.to(dtype=input_dtype)
262256

263257
num_expected_quantiles = len(quantile_levels)
264258
# input: tensor of shape (batch_size, context_length)
@@ -269,8 +263,10 @@ def test_pipeline_predict_quantiles(
269263
prediction_length=prediction_length,
270264
quantile_levels=quantile_levels,
271265
)
272-
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
273-
validate_tensor(mean, (4, prediction_length))
266+
validate_tensor(
267+
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
268+
)
269+
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
274270

275271
# input: batch_size-long list of tensors of shape (context_length,)
276272

@@ -280,8 +276,10 @@ def test_pipeline_predict_quantiles(
280276
prediction_length=prediction_length,
281277
quantile_levels=quantile_levels,
282278
)
283-
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
284-
validate_tensor(mean, (4, prediction_length))
279+
validate_tensor(
280+
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
281+
)
282+
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
285283

286284
# input: tensor of shape (context_length,)
287285

@@ -291,20 +289,23 @@ def test_pipeline_predict_quantiles(
291289
prediction_length=prediction_length,
292290
quantile_levels=quantile_levels,
293291
)
294-
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
295-
validate_tensor(mean, (1, prediction_length))
292+
validate_tensor(
293+
quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
294+
)
295+
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
296296

297297

298298
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
299-
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
299+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
300300
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
301301
pipeline = ChronosPipeline.from_pretrained(
302302
Path(__file__).parent / "dummy-chronos-model",
303303
device_map="cpu",
304304
torch_dtype=model_dtype,
305305
)
306306
d_model = pipeline.model.model.config.d_model
307-
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
307+
context = 10 * torch.rand(size=(4, 16)) + 10
308+
context = context.to(dtype=input_dtype)
308309
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
309310

310311
# input: tensor of shape (batch_size, context_length)

test/test_chronos_bolt.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,36 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
from pathlib import Path
2-
from typing import Tuple
35

46
import pytest
57
import torch
68

79
from chronos import BaseChronosPipeline, ChronosBoltPipeline
810
from chronos.chronos_bolt import InstanceNorm, Patch
9-
10-
11-
def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None:
12-
assert isinstance(input, torch.Tensor)
13-
assert input.shape == shape
11+
from test.util import validate_tensor
1412

1513

1614
def test_base_chronos_pipeline_loads_from_huggingface():
1715
BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-tiny", device_map="cpu")
1816

1917

2018
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
21-
def test_pipeline_predict(torch_dtype: str):
19+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
20+
def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype):
2221
pipeline = ChronosBoltPipeline.from_pretrained(
2322
Path(__file__).parent / "dummy-chronos-bolt-model",
2423
device_map="cpu",
2524
torch_dtype=torch_dtype,
2625
)
2726
context = 10 * torch.rand(size=(4, 16)) + 10
27+
context = context.to(dtype=input_dtype)
2828
expected_num_quantiles = len(pipeline.quantiles)
2929

3030
# input: tensor of shape (batch_size, context_length)
3131

3232
quantiles = pipeline.predict(context, prediction_length=3)
33-
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
33+
validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
3434

3535
with pytest.raises(ValueError):
3636
quantiles = pipeline.predict(
@@ -43,7 +43,7 @@ def test_pipeline_predict(torch_dtype: str):
4343
# input: batch_size-long list of tensors of shape (context_length,)
4444

4545
quantiles = pipeline.predict(list(context), prediction_length=3)
46-
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
46+
validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
4747

4848
with pytest.raises(ValueError):
4949
quantiles = pipeline.predict(
@@ -53,12 +53,12 @@ def test_pipeline_predict(torch_dtype: str):
5353
)
5454

5555
quantiles = pipeline.predict(list(context), prediction_length=65)
56-
validate_tensor(quantiles, (4, expected_num_quantiles, 65))
56+
validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32)
5757

5858
# input: tensor of shape (context_length,)
5959

6060
quantiles = pipeline.predict(context[0, ...], prediction_length=3)
61-
validate_tensor(quantiles, (1, expected_num_quantiles, 3))
61+
validate_tensor(quantiles, (1, expected_num_quantiles, 3), dtype=torch.float32)
6262

6363
with pytest.raises(ValueError):
6464
quantiles = pipeline.predict(
@@ -71,23 +71,28 @@ def test_pipeline_predict(torch_dtype: str):
7171
context[0, ...],
7272
prediction_length=65,
7373
)
74-
validate_tensor(quantiles, (1, expected_num_quantiles, 65))
74+
validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32)
7575

7676

7777
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
78+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
7879
@pytest.mark.parametrize("prediction_length", [3, 65])
7980
@pytest.mark.parametrize(
8081
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
8182
)
8283
def test_pipeline_predict_quantiles(
83-
torch_dtype: str, prediction_length: int, quantile_levels: list[int]
84+
torch_dtype: torch.dtype,
85+
input_dtype: torch.dtype,
86+
prediction_length: int,
87+
quantile_levels: list[int],
8488
):
8589
pipeline = ChronosBoltPipeline.from_pretrained(
8690
Path(__file__).parent / "dummy-chronos-bolt-model",
8791
device_map="cpu",
8892
torch_dtype=torch_dtype,
8993
)
9094
context = 10 * torch.rand(size=(4, 16)) + 10
95+
context = context.to(dtype=input_dtype)
9196

9297
num_expected_quantiles = len(quantile_levels)
9398
# input: tensor of shape (batch_size, context_length)
@@ -97,8 +102,10 @@ def test_pipeline_predict_quantiles(
97102
prediction_length=prediction_length,
98103
quantile_levels=quantile_levels,
99104
)
100-
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
101-
validate_tensor(mean, (4, prediction_length))
105+
validate_tensor(
106+
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
107+
)
108+
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
102109

103110
# input: batch_size-long list of tensors of shape (context_length,)
104111

@@ -107,8 +114,10 @@ def test_pipeline_predict_quantiles(
107114
prediction_length=prediction_length,
108115
quantile_levels=quantile_levels,
109116
)
110-
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
111-
validate_tensor(mean, (4, prediction_length))
117+
validate_tensor(
118+
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
119+
)
120+
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
112121

113122
# input: tensor of shape (context_length,)
114123

@@ -117,8 +126,10 @@ def test_pipeline_predict_quantiles(
117126
prediction_length=prediction_length,
118127
quantile_levels=quantile_levels,
119128
)
120-
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
121-
validate_tensor(mean, (1, prediction_length))
129+
validate_tensor(
130+
quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
131+
)
132+
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
122133

123134

124135
# The following tests have been taken from

test/util.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Optional, Tuple
2+
3+
import torch
4+
5+
6+
def validate_tensor(
7+
a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
8+
) -> None:
9+
assert isinstance(a, torch.Tensor)
10+
assert a.shape == shape
11+
12+
if dtype is not None:
13+
assert a.dtype == dtype

0 commit comments

Comments
 (0)