
Adding GPU acceleration to encode_jpeg #8375


Closed

Conversation

Contributor

@deekay42 deekay42 commented Apr 12, 2024

Summary:
I'm adding GPU support to the existing torchvision.io.encode_jpeg function. If the input tensors are on the GPU, the CUDA version is used; otherwise, the CPU version is used.
Additionally, I'm adding a new function torchvision.io.encode_jpegs (plural) which uses a fused kernel and may be faster than successive calls to the singular version, each of which incurs kernel launch overhead.
If it's alright, I'll be happy to refactor decode_jpeg to follow this convention in a follow-up PR.
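
For context, a minimal usage sketch of the proposed API (shapes and quality values are arbitrary; encode_jpegs is the new plural function described above):

import torch
from torchvision.io import encode_jpeg, encode_jpegs

# Single image: the CUDA path is selected automatically for GPU tensors.
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8, device="cuda")
encoded = encode_jpeg(img, quality=75)  # encoded bytes stay on the GPU

# Batch: the fused-kernel variant avoids per-call launch overhead.
imgs = [torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8, device="cuda") for _ in range(8)]
encoded_list = encode_jpegs(imgs, quality=75)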

Test Plan:

  1. pytest test -vvv
  2. ufmt format torchvision
  3. flake8 torchvision

Reviewers:
@NicolasHug


pytorch-bot bot commented Apr 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8375

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 37 Pending

As of commit 6c1c9fe with merge base 5181a85 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Member

@NicolasHug NicolasHug left a comment

Thanks for submitting the PR @deekay42.

Before going deeper into the review, we should try to make sure the existing tests are running and passing.

It looks like there are some build failures in various jobs; you can see the CI job logs via the links above.

There are a lot of failures, some of which are more important than others. I made comments below which correspond to some of those failures.

The failure that seems the most concerning at this stage is the one that looks like

2024-04-15T09:56:21.6662102Z AttributeError: '_OpNamespace' 'image' object has no attribute '_jpeg_version'

It comes from here, and the failure suggests that the image extension (where you added the new encoder) doesn't build properly with this PR.

How have you been testing those changes so far? Was it on a devvm/fbcode (with buck etc.)? If yes, that might be why you haven't seen those failures locally yet.

return output


def encode_jpegs(inputs: list[torch.Tensor], quality: int = 75) -> list[torch.Tensor]:
Member

Here and everywhere else, this should be List; otherwise it causes errors like:

2024-04-15T10:06:23.0758306Z     def encode_jpegs(inputs: list[torch.Tensor], quality: int = 75) -> list[torch.Tensor]:
2024-04-15T10:06:23.0759045Z TypeError: 'type' object is not subscriptable
Suggested change
def encode_jpegs(inputs: list[torch.Tensor], quality: int = 75) -> list[torch.Tensor]:
def encode_jpegs(inputs: List[torch.Tensor], quality: int = 75) -> List[torch.Tensor]:

Contributor Author

Hey Nicolas, thanks for taking a look!
I used a devGPU and didn't use buck or any internal tools. The PR builds and tests fine locally without any errors. I'm guessing this is because I'm using Python 3.11, while this test uses 3.8 (subscripting standard containers was introduced in PEP 585, which landed in Python 3.9). Should be an easy fix.
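
For reference, a minimal sketch of the version dependency (hypothetical stub; only the annotation matters here):

from typing import List  # portable spelling, works on Python 3.8+

import torch

# Builtin generics like list[torch.Tensor] are only subscriptable at runtime
# on Python >= 3.9 (PEP 585); on 3.8 the def line itself raises
# TypeError: 'type' object is not subscriptable.
def encode_jpegs(inputs: List[torch.Tensor], quality: int = 75) -> List[torch.Tensor]:
    ...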

@deekay42 deekay42 marked this pull request as draft April 15, 2024 18:31
@deekay42 deekay42 force-pushed the main branch 2 times, most recently from 1033dc4 to 7181f4c on April 16, 2024 at 20:56
Member

@NicolasHug NicolasHug left a comment

Thanks @deekay42 , I made a first pass over the code. This looks pretty good already!

@@ -483,6 +559,53 @@ def test_encode_jpeg_errors():
with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

@needs_cuda
def test_encode_jpeg_cuda_errors():
Member

test_encode_jpeg_cuda_errors and test_encode_jpegs_cuda_errors are currently sandwiched between test_encode_jpeg_errors and test_encode_jpeg, which isn't great as it scatters related tests (we're not necessarily doing a great job at that in that file already).

But it might be best to keep test_encode_jpeg_cuda_errors and test_encode_jpegs_cuda_errors just below test_encode_jpeg[s]_cuda so that related tests are grouped together?

Comment on lines 91 to 100
# next we check to see if the image is encodeable by PIL and
# if it isn't we don't include it in the test
try:
decoded_image_pil = F.to_pil_image(decoded_image_tv)
buf = io.BytesIO()
decoded_image_pil.save(buf, format="JPEG", quality=75)
except Exception:
# If PIL can't encode the image we can't either
return

encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda().contiguous(), quality=75)
decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())
Member

Instead of using PIL for that, we can just directly check whether the image is CMYK, similar to what we're doing here:

vision/test/test_image.py

Lines 425 to 426 in 2ae6a6d

if "cmyk" in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported")

Suggested change
# next we check to see if the image is encodeable by PIL and
# if it isn't we don't include it in the test
try:
decoded_image_pil = F.to_pil_image(decoded_image_tv)
buf = io.BytesIO()
decoded_image_pil.save(buf, format="JPEG", quality=75)
except Exception:
# If PIL can't encode the image we can't either
return
encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda().contiguous(), quality=75)
decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())
if "cmyk" in img_path:
pytest.xfail("Encoding a CMYK jpeg isn't supported")

create_status);
}
});
std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init);
Member

Noob Q: IIUC ::something means the name resolution starts at the global namespace rather than at the local namespace. But why is it needed here? There cannot be another nvjpeg_handle_creation_flag, right?

Contributor Author

It's not strictly needed here, but it makes it extra explicit that this flag lives in the global namespace.

cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());

// Create global nvJPEG handle
std::call_once(::nvjpeg_handle_creation_flag, nvjpeg_init);
Member

Same Q as above about the need for :: here?

Comment on lines 166 to 170
// For some reason I couldn't get nvjpegEncodeImage to work for grayscale
// images but nvjpegEncodeYUV seems to work fine. Annoyingly however,
// nvjpegEncodeYUV requires the source image plane pitch to be divisible by 8
// so we must pad the image if needed:
// https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
Member

Thanks for these details. The problem with the padding is that the resulting encoded image is incorrect: it has undesirable padded values, and its shape is also different from the original image. E.g. this simple test would fail:

import torch
from torchvision.io import decode_jpeg, encode_jpeg

orig = torch.randint(0, 256, size=(1, 99, 99), dtype=torch.uint8).cuda()
decoded = decode_jpeg(encode_jpeg(orig).cpu())
assert decoded.shape == orig.shape  # Fails: torch.Size([1, 99, 104]) != torch.Size([1, 99, 99])

It's best not to go with that solution. We can try to find a way to support grayscale images, but in the meantime we can just raise a loud TORCH_CHECK error if the image is grayscale.
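
For illustration, a sketch of the behavior that recommendation would give (the exact error message is up to the TORCH_CHECK text, so it is left unasserted here):

import pytest
import torch
from torchvision.io import encode_jpeg

gray = torch.randint(0, 256, (1, 64, 64), dtype=torch.uint8).cuda()
with pytest.raises(RuntimeError):  # grayscale CUDA encoding rejected loudly
    encode_jpeg(gray)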

Contributor Author

Ok, I'll remove for now.

Comment on lines 77 to 81
// Due to the required input format of nvjpeg, tensors must be contiguous in
// memory. We could do the conversion to contiguous here, but that comes with
// a performance penalty that would be invisible to the user. Better to
// make this explicit and push the conversion to user code.
TORCH_CHECK(
image.is_contiguous(),
"All input tensors must be contiguous. Call tensor.contiguous() before calling this function.")
Member

I understand the intent, but pytorch ops tend to do that kind of implicit conversion on behalf of the user when there's no other alternative (example:

auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();
).

One way to look at it is: if the user has a non-contiguous image, there are only 3 scenarios:

  1. a loud error
  2. an implicit conversion (by us)
  3. an explicit conversion (by the user)

2. and 3. result in the exact same perf, so I'm not sure we're saving them from anything by raising a loud error. If anything, this might create a sub-optimal user experience.

So here I'd recommend to just do the conversion on behalf of the user.
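
For illustration, a minimal sketch of scenarios 2 and 3 (assumes the CUDA path from this PR and the implicit-conversion behavior recommended above; sizes are arbitrary):

import torch
from torchvision.io import encode_jpeg

# A permuted HWC -> CHW view is typically non-contiguous.
img = torch.randint(0, 256, (100, 100, 3), dtype=torch.uint8, device="cuda").permute(2, 0, 1)
assert not img.is_contiguous()

encoded = encode_jpeg(img.contiguous())  # scenario 3: explicit copy by the user
encoded = encode_jpeg(img)               # scenario 2: implicit copy inside the op; same cost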

Contributor Author

Makes sense.

"The number of channels should be 1 or 3, got: ", image.size(0));
}

torch::Device device = images[0].device();
Member

Should we assert that all images are on the same device as well?

Right now, if we have [img0, img1] respectively on cuda:0 and cuda:1, the encoded output for img1 will be allocated on cuda:0 which is probably not desirable?
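
A hypothetical repro (requires at least two CUDA devices; encode_jpegs is the plural function added in this PR):

import torch
from torchvision.io import encode_jpegs

img0 = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8, device="cuda:0")
img1 = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8, device="cuda:1")

# Without a same-device check, the encoded output for img1 would silently be
# allocated on cuda:0; a loud error here seems preferable.
encode_jpegs([img0, img1])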

Contributor Author

Added some validation and a test case.

Comment on lines 148 to 149
const cudaStream_t stream,
const torch::Device& device,
const nvjpegEncoderState_t nv_enc_state,
const nvjpegEncoderParams_t nv_enc_params) {
Member

Should we pass all these as const refs?

Contributor Author

Yes!

Comment on lines 156 to 155
nvjpegStatus_t samplingSetResult = nvjpegEncoderParamsSetSamplingFactors(
nv_enc_params, channels == 1 ? NVJPEG_CSS_GRAY : NVJPEG_CSS_444, stream);
Member

Curious about this call, as I don't see it in https://docs.nvidia.com/cuda/nvjpeg/index.html#jpeg-encoding-example. Why is it needed?

Contributor Author

This is to handle grayscale images, but it won't be needed now that grayscale is out of scope (for now).

Comment on lines 72 to 74
input: torch.Tensor,
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
Member

Nit: best to avoid this kind of unrelated formatting change (I assume those come from the internal linter?) as they distract from the review and also add noise to git blame.

You can keep them here as those aren't too disruptive, just letting you know for future PRs.

Contributor Author

Got it - yea, I think that was the linter.
