Add support for image_read to take some bytes as an input #8020

Closed
tchaton opened this issue Oct 5, 2023 · 4 comments

tchaton commented Oct 5, 2023

🚀 The feature

Hey there,

For optimisation reasons, I am reading the bytes for a given JPEG image.

I would like to be able to deserialize it using torchvision by providing the bytes directly.

Best,
T.C

Motivation, pitch

I am encoding several images into a single shard for streaming purposes. Therefore, I can't provide a path to image_read.

Alternatives

I am using PIL right now and was hoping for speed gains from torchvision.

Additional context

No response

@NicolasHug (Member) commented:

Hi @tchaton, I think decode_image is what you're looking for. Feel free to re-open if not.


tchaton commented Oct 7, 2023

Hey @NicolasHug.

I looked into decode_image, and it is great in terms of raw speed (17 times faster than PIL).
My images are grouped into one large binary file, and I want to read them as fast as possible.
The first option with torch.frombuffer looks like a great solution. Thanks!

Would decode_image work on GPU? Let me try it ;)

import numpy as np
import torch
from torchvision.io import decode_image
from torchvision.transforms import ToTensor
from time import time
from io import BytesIO
from PIL import Image


# Option 1: read the raw bytes, wrap them with torch.frombuffer, decode.
t0 = time()
with open("n02110063_14457.JPEG", "rb") as f:
    data = f.read()
array = torch.frombuffer(data, dtype=torch.uint8)
array_torchvision = decode_image(array)
print(time() - t0)

# Option 2: read the file straight into a NumPy array, then decode.
t0 = time()
array = np.fromfile("n02110063_14457.JPEG", dtype=np.uint8)
array_torchvision = decode_image(torch.from_numpy(array))
print(time() - t0)


# Baseline: decode with PIL and convert back to a uint8 tensor.
t0 = time()
with open("n02110063_14457.JPEG", "rb") as f:
    data = f.read()

inp = BytesIO(data)
array_jpeg = ToTensor()(Image.open(inp))
array_jpeg = (array_jpeg * 255.).to(torch.uint8)
print(time() - t0)

assert torch.equal(array_torchvision, array_jpeg)
Output (seconds):

0.004538536071777344
0.004034280776977539
0.07298660278320312


tchaton commented Oct 7, 2023

Hey @NicolasHug, I tried using GPU decoding and this leads to a segmentation fault.

from torchvision.io import decode_jpeg

t0 = time()
with open("n02110063_14457.JPEG", "rb") as f:
    data = f.read()
array = torch.frombuffer(data, dtype=torch.uint8)
array_torchvision = decode_jpeg(array.to("cuda"))  # segfaults
print(time() - t0)

This is odd, because I can see the CUDA path in decode_jpeg:

    if device.type == "cuda":
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value)

OK, it seems it works when the device is passed as an argument instead:

decode_jpeg(array, device="cuda")

But it took 0.937 seconds, versus 0.015 seconds when the data is decoded on the CPU and moved to the GPU afterwards.

@NicolasHug (Member) commented:

yeaahh... The GPU decoder is still in beta and has a few rough edge cases (see also #4378, where it seems to leak memory depending on the CUDA version). Perhaps your best bet is to avoid it for now, or maybe look at https://github.com/itsliupeng/torchnvjpeg
