@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
67
67
src->pub .next_input_byte = src->data ;
68
68
}
69
69
70
+ inline unsigned char clamped_cmyk_rgb_convert (
71
+ unsigned char k,
72
+ unsigned char cmy) {
73
+ // Inspired from Pillow:
74
+ // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
75
+ int v = k * cmy + 128 ;
76
+ v = ((v >> 8 ) + v) >> 8 ;
77
+ return std::clamp (k - v, 0 , 255 );
78
+ }
79
+
80
+ void convert_line_cmyk_to_rgb (
81
+ j_decompress_ptr cinfo,
82
+ const unsigned char * cmyk_line,
83
+ unsigned char * rgb_line) {
84
+ int width = cinfo->output_width ;
85
+ for (int i = 0 ; i < width; ++i) {
86
+ int c = cmyk_line[i * 4 + 0 ];
87
+ int m = cmyk_line[i * 4 + 1 ];
88
+ int y = cmyk_line[i * 4 + 2 ];
89
+ int k = cmyk_line[i * 4 + 3 ];
90
+
91
+ rgb_line[i * 3 + 0 ] = clamped_cmyk_rgb_convert (k, 255 - c);
92
+ rgb_line[i * 3 + 1 ] = clamped_cmyk_rgb_convert (k, 255 - m);
93
+ rgb_line[i * 3 + 2 ] = clamped_cmyk_rgb_convert (k, 255 - y);
94
+ }
95
+ }
96
+
97
+ inline unsigned char rgb_to_gray (int r, int g, int b) {
98
+ // Inspired from Pillow:
99
+ // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
100
+ return (r * 19595 + g * 38470 + b * 7471 + 0x8000 ) >> 16 ;
101
+ }
102
+
103
+ void convert_line_cmyk_to_gray (
104
+ j_decompress_ptr cinfo,
105
+ const unsigned char * cmyk_line,
106
+ unsigned char * gray_line) {
107
+ int width = cinfo->output_width ;
108
+ for (int i = 0 ; i < width; ++i) {
109
+ int c = cmyk_line[i * 4 + 0 ];
110
+ int m = cmyk_line[i * 4 + 1 ];
111
+ int y = cmyk_line[i * 4 + 2 ];
112
+ int k = cmyk_line[i * 4 + 3 ];
113
+
114
+ int r = clamped_cmyk_rgb_convert (k, 255 - c);
115
+ int g = clamped_cmyk_rgb_convert (k, 255 - m);
116
+ int b = clamped_cmyk_rgb_convert (k, 255 - y);
117
+
118
+ gray_line[i] = rgb_to_gray (r, g, b);
119
+ }
120
+ }
121
+
70
122
} // namespace
71
123
72
124
torch::Tensor decode_jpeg (const torch::Tensor& data, ImageReadMode mode) {
@@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
102
154
jpeg_read_header (&cinfo, TRUE );
103
155
104
156
int channels = cinfo.num_components ;
157
+ bool cmyk_to_rgb_or_gray = false ;
105
158
106
159
if (mode != IMAGE_READ_MODE_UNCHANGED) {
107
160
switch (mode) {
108
161
case IMAGE_READ_MODE_GRAY:
109
- if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
162
+ if (cinfo.jpeg_color_space == JCS_CMYK ||
163
+ cinfo.jpeg_color_space == JCS_YCCK) {
164
+ cinfo.out_color_space = JCS_CMYK;
165
+ cmyk_to_rgb_or_gray = true ;
166
+ } else {
110
167
cinfo.out_color_space = JCS_GRAYSCALE;
111
- channels = 1 ;
112
168
}
169
+ channels = 1 ;
113
170
break ;
114
171
case IMAGE_READ_MODE_RGB:
115
- if (cinfo.jpeg_color_space != JCS_RGB) {
172
+ if (cinfo.jpeg_color_space == JCS_CMYK ||
173
+ cinfo.jpeg_color_space == JCS_YCCK) {
174
+ cinfo.out_color_space = JCS_CMYK;
175
+ cmyk_to_rgb_or_gray = true ;
176
+ } else {
116
177
cinfo.out_color_space = JCS_RGB;
117
- channels = 3 ;
118
178
}
179
+ channels = 3 ;
119
180
break ;
120
181
/*
121
182
* Libjpeg does not support converting from CMYK to grayscale etc. There
@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
139
200
auto tensor =
140
201
torch::empty ({int64_t (height), int64_t (width), channels}, torch::kU8 );
141
202
auto ptr = tensor.data_ptr <uint8_t >();
203
+ torch::Tensor cmyk_line_tensor;
204
+ if (cmyk_to_rgb_or_gray) {
205
+ cmyk_line_tensor = torch::empty ({int64_t (width), 4 }, torch::kU8 );
206
+ }
207
+
142
208
while (cinfo.output_scanline < cinfo.output_height ) {
143
209
/* jpeg_read_scanlines expects an array of pointers to scanlines.
144
210
* Here the array is only one element long, but you could ask for
145
211
* more than one scanline at a time if that's more convenient.
146
212
*/
147
- jpeg_read_scanlines (&cinfo, &ptr, 1 );
213
+ if (cmyk_to_rgb_or_gray) {
214
+ auto cmyk_line_ptr = cmyk_line_tensor.data_ptr <uint8_t >();
215
+ jpeg_read_scanlines (&cinfo, &cmyk_line_ptr, 1 );
216
+
217
+ if (channels == 3 ) {
218
+ convert_line_cmyk_to_rgb (&cinfo, cmyk_line_ptr, ptr);
219
+ } else if (channels == 1 ) {
220
+ convert_line_cmyk_to_gray (&cinfo, cmyk_line_ptr, ptr);
221
+ }
222
+ } else {
223
+ jpeg_read_scanlines (&cinfo, &ptr, 1 );
224
+ }
148
225
ptr += stride;
149
226
}
150
227
0 commit comments