Skip to content

Commit 67f0084

Browse files
authored
Fix padding for int contexts (#227)
*Issue #, if available:* On Linux, the final call to `.to` creates trouble when input tensors are integer. For example: ``` >>> a = torch.tensor([1]) >>> b = torch.stack([torch.full((1,), torch.nan), a]) >>> b tensor([[nan], [1.]]) >>> b.to(a) tensor([[-9223372036854775808], [ 1]]) ``` By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 47cac08 commit 67f0084

File tree

4 files changed

+32
-3
lines changed

4 files changed

+32
-3
lines changed

src/chronos/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
1717
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
1818
)
1919
padded.append(torch.concat((padding, c), dim=-1))
20-
return torch.stack(padded).to(tensors[0])
20+
return torch.stack(padded)

test/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-License-Identifier: Apache-2.0

test/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
import torch
6+
7+
from chronos.utils import left_pad_and_stack_1D
8+
9+
10+
@pytest.mark.parametrize(
11+
"tensors",
12+
[
13+
[
14+
torch.tensor([2.0, 3.0], dtype=dtype),
15+
torch.tensor([4.0, 5.0, 6.0], dtype=dtype),
16+
torch.tensor([7.0, 8.0, 9.0, 10.0], dtype=dtype),
17+
]
18+
for dtype in [torch.int, torch.float16, torch.float32]
19+
],
20+
)
21+
def test_pad_and_stack(tensors: list):
22+
stacked_and_padded = left_pad_and_stack_1D(tensors)
23+
24+
assert stacked_and_padded.dtype == torch.float32
25+
assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors))
26+
27+
ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype)
28+
29+
assert torch.sum(torch.nan_to_num(stacked_and_padded, nan=0)) == torch.sum(ref)

test/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ def validate_tensor(
1010
assert a.shape == shape
1111

1212
if dtype is not None:
13-
assert a.dtype == dtype
13+
assert a.dtype == dtype

0 commit comments

Comments
 (0)