Skip to content

Commit 78159d6

Browse files
authored
Extend the supported types of decodePNG (#2984)
* Add support of different color types in readpng. * Adding test images and unit-tests. * Use closest possible type. * Fix formatting.
1 parent 481ef51 commit 78159d6

File tree

7 files changed

+31
-10
lines changed

7 files changed

+31
-10
lines changed
433 Bytes
Loading
590 Bytes
Loading
1.12 KB
Loading
575 Bytes
Loading
1.12 KB
Loading

test/test_image.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717

1818
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
19-
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
19+
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
20+
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
2021
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
2122

2223

@@ -133,9 +134,12 @@ def test_write_jpeg(self):
133134
self.assertEqual(torch_bytes, pil_bytes)
134135

135136
def test_decode_png(self):
136-
for img_path in get_images(IMAGE_DIR, ".png"):
137+
for img_path in get_images(FAKEDATA_DIR, ".png"):
137138
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
138-
img_pil = img_pil.permute(2, 0, 1)
139+
if len(img_pil.shape) == 3:
140+
img_pil = img_pil.permute(2, 0, 1)
141+
else:
142+
img_pil = img_pil.unsqueeze(0)
139143
data = read_file(img_path)
140144
img_lpng = decode_png(data)
141145
self.assertTrue(img_lpng.equal(img_pil))

torchvision/csrc/cpu/image/readpng_cpu.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
7171
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
7272
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
7373
}
74-
if (color_type != PNG_COLOR_TYPE_RGB) {
75-
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
76-
TORCH_CHECK(
77-
color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.")
74+
75+
int channels;
76+
switch (color_type) {
77+
case PNG_COLOR_TYPE_RGB:
78+
channels = 3;
79+
break;
80+
case PNG_COLOR_TYPE_RGB_ALPHA:
81+
channels = 4;
82+
break;
83+
case PNG_COLOR_TYPE_GRAY:
84+
channels = 1;
85+
break;
86+
case PNG_COLOR_TYPE_GRAY_ALPHA:
87+
channels = 2;
88+
break;
89+
case PNG_COLOR_TYPE_PALETTE:
90+
channels = 1;
91+
break;
92+
default:
93+
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
94+
TORCH_CHECK(false, "Image color type is not supported.");
7895
}
7996

80-
auto tensor =
81-
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
97+
auto tensor = torch::empty(
98+
{int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8);
8299
auto ptr = tensor.accessor<uint8_t, 3>().data();
83100
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
84-
for (decltype(height) i = 0; i < height; ++i) {
101+
for (png_uint_32 i = 0; i < height; ++i) {
85102
png_read_row(png_ptr, ptr, nullptr);
86103
ptr += bytes;
87104
}

0 commit comments

Comments
 (0)