Skip to content

Commit 63b59ba

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Added support for CMYK in decode_jpeg (#7741)
Summary: (Note: this ignores all push blocking failures!) Reviewed By: matteobettini Differential Revision: D48900406 fbshipit-source-id: 05a9d0086cf88677e2249c96638252c8ab92d637 Co-authored-by: Nicolas Hug <[email protected]>
1 parent dfdfd88 commit 63b59ba

File tree

2 files changed

+83
-9
lines changed

2 files changed

+83
-9
lines changed

test/test_image.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
8383
with Image.open(img_path) as img:
8484
is_cmyk = img.mode == "CMYK"
8585
if pil_mode is not None:
86-
if is_cmyk:
87-
# libjpeg does not support the conversion
88-
pytest.xfail("Decoding a CMYK jpeg isn't supported")
8986
img = img.convert(pil_mode)
9087
img_pil = torch.from_numpy(np.array(img))
91-
if is_cmyk:
88+
if is_cmyk and mode == ImageReadMode.UNCHANGED:
9289
# flip the colors to match libjpeg
9390
img_pil = 255 - img_pil
9491

torchvision/csrc/io/image/cpu/decode_jpeg.cpp

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
6767
src->pub.next_input_byte = src->data;
6868
}
6969

70+
inline unsigned char clamped_cmyk_rgb_convert(
71+
unsigned char k,
72+
unsigned char cmy) {
73+
// Inspired from Pillow:
74+
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
75+
int v = k * cmy + 128;
76+
v = ((v >> 8) + v) >> 8;
77+
return std::clamp(k - v, 0, 255);
78+
}
79+
80+
void convert_line_cmyk_to_rgb(
81+
j_decompress_ptr cinfo,
82+
const unsigned char* cmyk_line,
83+
unsigned char* rgb_line) {
84+
int width = cinfo->output_width;
85+
for (int i = 0; i < width; ++i) {
86+
int c = cmyk_line[i * 4 + 0];
87+
int m = cmyk_line[i * 4 + 1];
88+
int y = cmyk_line[i * 4 + 2];
89+
int k = cmyk_line[i * 4 + 3];
90+
91+
rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c);
92+
rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m);
93+
rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y);
94+
}
95+
}
96+
97+
inline unsigned char rgb_to_gray(int r, int g, int b) {
98+
// Inspired from Pillow:
99+
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
100+
return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16;
101+
}
102+
103+
void convert_line_cmyk_to_gray(
104+
j_decompress_ptr cinfo,
105+
const unsigned char* cmyk_line,
106+
unsigned char* gray_line) {
107+
int width = cinfo->output_width;
108+
for (int i = 0; i < width; ++i) {
109+
int c = cmyk_line[i * 4 + 0];
110+
int m = cmyk_line[i * 4 + 1];
111+
int y = cmyk_line[i * 4 + 2];
112+
int k = cmyk_line[i * 4 + 3];
113+
114+
int r = clamped_cmyk_rgb_convert(k, 255 - c);
115+
int g = clamped_cmyk_rgb_convert(k, 255 - m);
116+
int b = clamped_cmyk_rgb_convert(k, 255 - y);
117+
118+
gray_line[i] = rgb_to_gray(r, g, b);
119+
}
120+
}
121+
70122
} // namespace
71123

72124
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
@@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
102154
jpeg_read_header(&cinfo, TRUE);
103155

104156
int channels = cinfo.num_components;
157+
bool cmyk_to_rgb_or_gray = false;
105158

106159
if (mode != IMAGE_READ_MODE_UNCHANGED) {
107160
switch (mode) {
108161
case IMAGE_READ_MODE_GRAY:
109-
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
162+
if (cinfo.jpeg_color_space == JCS_CMYK ||
163+
cinfo.jpeg_color_space == JCS_YCCK) {
164+
cinfo.out_color_space = JCS_CMYK;
165+
cmyk_to_rgb_or_gray = true;
166+
} else {
110167
cinfo.out_color_space = JCS_GRAYSCALE;
111-
channels = 1;
112168
}
169+
channels = 1;
113170
break;
114171
case IMAGE_READ_MODE_RGB:
115-
if (cinfo.jpeg_color_space != JCS_RGB) {
172+
if (cinfo.jpeg_color_space == JCS_CMYK ||
173+
cinfo.jpeg_color_space == JCS_YCCK) {
174+
cinfo.out_color_space = JCS_CMYK;
175+
cmyk_to_rgb_or_gray = true;
176+
} else {
116177
cinfo.out_color_space = JCS_RGB;
117-
channels = 3;
118178
}
179+
channels = 3;
119180
break;
120181
/*
121182
* Libjpeg does not support converting from CMYK to grayscale etc. There
@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
139200
auto tensor =
140201
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
141202
auto ptr = tensor.data_ptr<uint8_t>();
203+
torch::Tensor cmyk_line_tensor;
204+
if (cmyk_to_rgb_or_gray) {
205+
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
206+
}
207+
142208
while (cinfo.output_scanline < cinfo.output_height) {
143209
/* jpeg_read_scanlines expects an array of pointers to scanlines.
144210
* Here the array is only one element long, but you could ask for
145211
* more than one scanline at a time if that's more convenient.
146212
*/
147-
jpeg_read_scanlines(&cinfo, &ptr, 1);
213+
if (cmyk_to_rgb_or_gray) {
214+
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
215+
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
216+
217+
if (channels == 3) {
218+
convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
219+
} else if (channels == 1) {
220+
convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
221+
}
222+
} else {
223+
jpeg_read_scanlines(&cinfo, &ptr, 1);
224+
}
148225
ptr += stride;
149226
}
150227

0 commit comments

Comments
 (0)