Skip to content

Commit caf3ca8

Browse files
fmassa and vfdev-5
authored and committed
Add decode_image op (pytorch#2718)
* Add decode_image op
* Fix lint
* More lint
* Add C10_EXPORT
1 parent 423296e commit caf3ca8

File tree

8 files changed

+112
-33
lines changed

8 files changed

+112
-33
lines changed

test/test_image.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torchvision
99
from PIL import Image
1010
from torchvision.io.image import (
11-
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg)
11+
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file)
1212
import numpy as np
1313

1414
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
@@ -44,10 +44,10 @@ def test_decode_jpeg(self):
4444
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
4545
self.assertTrue(img_ljpeg.equal(img_pil))
4646

47-
with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."):
47+
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
4848
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
4949

50-
with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."):
50+
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
5151
decode_jpeg(torch.empty((100, ), dtype=torch.float16))
5252

5353
with self.assertRaises(RuntimeError):
@@ -149,11 +149,24 @@ def test_decode_png(self):
149149
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
150150
self.assertTrue(img_lpng.equal(img_pil))
151151

152-
with self.assertRaises(ValueError):
152+
with self.assertRaises(RuntimeError):
153153
decode_png(torch.empty((), dtype=torch.uint8))
154154
with self.assertRaises(RuntimeError):
155155
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
156156

157+
def test_decode_image(self):
    """Check that decode_image matches the reference pixels for JPEG and PNG.

    JPEG inputs are compared against the tensor saved next to each image
    (same file name with a ``.pth`` extension); PNG inputs are compared
    against the pixels PIL produces for the same file.
    """
    for img_path in get_images(IMAGE_ROOT, ".jpg"):
        # Swap only the extension. str.replace('jpg', 'pth') would also
        # rewrite any 'jpg' occurring in a directory or base name.
        reference_path = os.path.splitext(img_path)[0] + '.pth'
        img_pil = torch.load(reference_path)
        # Move the last axis first so the layout matches the decoder output.
        img_pil = img_pil.permute(2, 0, 1)
        img_ljpeg = decode_image(_read_file(img_path))
        self.assertTrue(img_ljpeg.equal(img_pil))

    for img_path in get_images(IMAGE_DIR, ".png"):
        img_pil = torch.from_numpy(np.array(Image.open(img_path)))
        img_pil = img_pil.permute(2, 0, 1)
        img_lpng = decode_image(_read_file(img_path))
        self.assertTrue(img_lpng.equal(img_pil))
169+
157170

158171
if __name__ == '__main__':
159172
unittest.main()

torchvision/csrc/cpu/image/image.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ static auto registry = torch::RegisterOperators()
1616
.op("image::decode_png", &decodePNG)
1717
.op("image::decode_jpeg", &decodeJPEG)
1818
.op("image::encode_jpeg", &encodeJPEG)
19-
.op("image::write_jpeg", &writeJPEG);
19+
.op("image::write_jpeg", &writeJPEG)
20+
.op("image::decode_image", &decode_image);

torchvision/csrc/cpu/image/image.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
21
#pragma once
32

43
// Comment
54
#include <torch/script.h>
65
#include <torch/torch.h>
6+
#include "read_image_cpu.h"
77
#include "readjpeg_cpu.h"
88
#include "readpng_cpu.h"
99
#include "writejpeg_cpu.h"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "read_image_cpu.h"
2+
#include <string.h>
3+
4+
// Decode an encoded image by inspecting its leading signature bytes and
// forwarding the tensor to the matching decoder.
//
// data: non-empty, 1-dimensional uint8 tensor holding the encoded bytes.
// Returns the result of decodeJPEG(data) or decodePNG(data).
// Fails (TORCH_CHECK) when the dtype is not uint8, when the tensor is not a
// non-empty 1-dimensional tensor, or when the leading bytes match neither the
// JPEG nor the PNG signature.
torch::Tensor decode_image(const torch::Tensor& data) {
  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  // Check that the input tensor is 1-dimensional
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");

  // NOTE(review): the raw pointer is read as consecutive bytes, which assumes
  // a contiguous tensor -- confirm callers never pass a strided 1-D view.
  auto datap = data.data_ptr<uint8_t>();
  const int64_t num_bytes = data.numel();

  const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
  const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

  // Compare a signature only when the buffer holds at least that many bytes;
  // otherwise memcmp would read past the end of a 1-3 byte input.
  if (num_bytes >= 3 && memcmp(jpeg_signature, datap, 3) == 0) {
    return decodeJPEG(data);
  } else if (num_bytes >= 4 && memcmp(png_signature, datap, 4) == 0) {
    return decodePNG(data);
  } else {
    TORCH_CHECK(
        false,
        "Unsupported image file. Only jpeg and png ",
        "are currently supported.");
  }
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#include "readjpeg_cpu.h"
4+
#include "readpng_cpu.h"
5+
6+
// Decodes the encoded image bytes in `data` (a non-empty 1-dimensional uint8
// tensor), choosing the JPEG or PNG decoder from the leading signature bytes.
// Fails via TORCH_CHECK for any other input.
C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);

torchvision/csrc/cpu/image/readjpeg_cpu.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ static void torch_jpeg_set_source_mgr(
7272
}
7373

7474
torch::Tensor decodeJPEG(const torch::Tensor& data) {
75+
// Check that the input tensor dtype is uint8
76+
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
77+
// Check that the input tensor is 1-dimensional
78+
TORCH_CHECK(
79+
data.dim() == 1 && data.numel() > 0,
80+
"Expected a non empty 1-dimensional tensor");
81+
7582
struct jpeg_decompress_struct cinfo;
7683
struct torch_jpeg_error_mgr jerr;
7784

torchvision/csrc/cpu/image/readpng_cpu.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
1313
#include <png.h>
1414

1515
torch::Tensor decodePNG(const torch::Tensor& data) {
16+
// Check that the input tensor dtype is uint8
17+
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
18+
// Check that the input tensor is 1-dimensional
19+
TORCH_CHECK(
20+
data.dim() == 1 && data.numel() > 0,
21+
"Expected a non empty 1-dimensional tensor");
22+
1623
auto png_ptr =
1724
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
1825
TORCH_CHECK(png_ptr, "libpng read structure allocation failed!")

torchvision/io/image.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,29 @@
2323
pass
2424

2525

26+
def _read_file(path: str) -> torch.Tensor:
    """
    Loads the raw bytes of a file into a one dimensional uint8 tensor.

    Arguments:
        path (str): path of the file to read.

    Returns:
        output (Tensor[file_size]): the file content, one byte per element.

    Raises:
        ValueError: if ``path`` is not an existing regular file, or if the
            file is empty.
    """
    if os.path.isfile(path):
        num_bytes = os.path.getsize(path)
        if num_bytes != 0:
            # The tensor is sized to exactly the number of bytes on disk.
            return torch.from_file(path, dtype=torch.uint8, size=num_bytes)
        raise ValueError("Expected a non empty file.")
    raise ValueError("Expected a valid file path.")
35+
36+
2637
def decode_png(input: torch.Tensor) -> torch.Tensor:
2738
"""
2839
Decodes a PNG image into a 3 dimensional RGB Tensor.
2940
The values of the output tensor are uint8 between 0 and 255.
3041
3142
Arguments:
32-
input (Tensor[1]): a one dimensional int8 tensor containing
43+
input (Tensor[1]): a one dimensional uint8 tensor containing
3344
the raw bytes of the PNG image.
3445
3546
Returns:
3647
output (Tensor[3, image_height, image_width])
3748
"""
38-
if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1: # type: ignore[attr-defined]
39-
raise ValueError("Expected a non empty 1-dimensional tensor.")
40-
41-
if not input.dtype == torch.uint8:
42-
raise ValueError("Expected a torch.uint8 tensor.")
4349
output = torch.ops.image.decode_png(input)
4450
return output
4551

@@ -55,13 +61,7 @@ def read_png(path: str) -> torch.Tensor:
5561
Returns:
5662
output (Tensor[3, image_height, image_width])
5763
"""
58-
if not os.path.isfile(path):
59-
raise ValueError("Expected a valid file path.")
60-
61-
size = os.path.getsize(path)
62-
if size == 0:
63-
raise ValueError("Expected a non empty file.")
64-
data = torch.from_file(path, dtype=torch.uint8, size=size)
64+
data = _read_file(path)
6565
return decode_png(data)
6666

6767

@@ -70,17 +70,11 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
7070
Decodes a JPEG image into a 3 dimensional RGB Tensor.
7171
The values of the output tensor are uint8 between 0 and 255.
7272
Arguments:
73-
input (Tensor[1]): a one dimensional int8 tensor containing
73+
input (Tensor[1]): a one dimensional uint8 tensor containing
7474
the raw bytes of the JPEG image.
7575
Returns:
7676
output (Tensor[3, image_height, image_width])
7777
"""
78-
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: # type: ignore[attr-defined]
79-
raise ValueError("Expected a non empty 1-dimensional tensor.")
80-
81-
if not input.dtype == torch.uint8:
82-
raise ValueError("Expected a torch.uint8 tensor.")
83-
8478
output = torch.ops.image.decode_jpeg(input)
8579
return output
8680

@@ -94,13 +88,7 @@ def read_jpeg(path: str) -> torch.Tensor:
9488
Returns:
9589
output (Tensor[3, image_height, image_width])
9690
"""
97-
if not os.path.isfile(path):
98-
raise ValueError("Expected a valid file path.")
99-
100-
size = os.path.getsize(path)
101-
if size == 0:
102-
raise ValueError("Expected a non empty file.")
103-
data = torch.from_file(path, dtype=torch.uint8, size=size)
91+
data = _read_file(path)
10492
return decode_jpeg(data)
10593

10694

@@ -141,3 +129,33 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
141129
'between 1 and 100')
142130

143131
torch.ops.image.write_jpeg(input, filename, quality)
132+
133+
134+
def decode_image(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes an encoded JPEG or PNG image into a 3 dimensional RGB Tensor,
    forwarding the bytes to the ``image::decode_image`` operator, which
    detects the format.

    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor): a one dimensional uint8 tensor containing
            the raw bytes of the PNG or JPEG image.
    Returns:
        output (Tensor[3, image_height, image_width])
    """
    return torch.ops.image.decode_image(input)
149+
150+
151+
def read_image(path: str) -> torch.Tensor:
    """
    Loads a JPEG or PNG file from disk and decodes it into a 3 dimensional
    RGB Tensor.

    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        path (str): path of the JPEG or PNG image.
    Returns:
        output (Tensor[3, image_height, image_width])
    """
    # Read the raw bytes first, then decode them.
    return decode_image(_read_file(path))

0 commit comments

Comments
 (0)