|
| 1 | +import concurrent.futures |
1 | 2 | import glob
|
2 | 3 | import io
|
3 | 4 | import os
|
|
10 | 11 | import requests
|
11 | 12 | import torch
|
12 | 13 | import torchvision.transforms.functional as F
|
13 |
| -from common_utils import assert_equal, IN_OSS_CI, needs_cuda |
| 14 | +from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda |
14 | 15 | from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
|
15 | 16 | from torchvision.io.image import (
|
16 | 17 | _read_png_16,
|
@@ -508,6 +509,200 @@ def test_encode_jpeg(img_path, scripted):
|
508 | 509 | assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
|
509 | 510 |
|
510 | 511 |
|
@needs_cuda
def test_encode_jpeg_cuda_device_param():
    """Check that CUDA JPEG encoding does not depend on how the device is spelled.

    The first non-CMYK test image is encoded once per device spelling: the
    ``"cuda"`` string, ``torch.device("cuda")`` and an explicit
    ``torch.device("cuda:i")`` for every visible GPU. All encoded outputs must
    be identical, and the current CUDA device and stream must be the same
    after the calls as before them.
    """
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_image(path)

    # Snapshot the global CUDA state so we can check the encode calls leave it unchanged.
    current_device = torch.cuda.current_device()
    current_stream = torch.cuda.current_stream()

    num_devices = torch.cuda.device_count()
    devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
    results = [encode_jpeg(data.to(device=device)) for device in devices]
    assert len(results) == len(devices)

    reference = results[0].cpu()
    for result in results:
        # torch.equal also compares sizes, so a length mismatch fails the assertion
        # instead of raising a broadcasting error from an elementwise ``==``.
        assert torch.equal(result.cpu(), reference)

    assert current_device == torch.cuda.current_device()
    assert current_stream == torch.cuda.current_stream()
| 532 | + |
| 533 | + |
@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
    """Round-trip a single image through the CUDA JPEG encoder.

    The image at ``img_path`` is read, laid out either contiguously or as
    channels-last (``contiguous``), moved to CUDA and encoded at quality 75,
    optionally through the TorchScript-compiled ``encode_jpeg`` (``scripted``).
    The encoded bytes are decoded again from a CPU copy and must stay close to
    the original pixels (mean absolute difference below 3).

    CMYK and single-channel images are marked as expected failures.
    """
    decoded_image_tv = read_image(img_path)

    if "cmyk" in img_path:
        pytest.xfail("Encoding a CMYK jpeg isn't supported")
    if decoded_image_tv.shape[0] == 1:
        # It is the *encoding* step that rejects single-channel input, not the decoding.
        pytest.xfail("Encoding a grayscale jpeg isn't supported")
    # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013

    # Only pay for scripting once we know the test is actually going to run.
    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    # Add a batch dim so that a 4D memory format can be requested, then drop it again.
    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=memory_format)[0]

    encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
    decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

    # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
    # instead, we re-decode the encoded image and compare to the original
    abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
    assert abs_mean_diff < 3
| 561 | + |
| 562 | + |
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
    """Round-trip a batch of images through ``encode_jpeg`` on ``device``.

    All non-CMYK, multi-channel test images are encoded in a single batched
    call at quality 75 (optionally via the TorchScript-compiled function),
    decoded again and compared with the originals (mean absolute difference
    below 3). The same batch is then encoded concurrently from several
    threads; every thread must return identical outputs that also decode
    close to the originals.
    """
    if device == "cpu" and IS_MACOS:
        pytest.skip("https://github.com/pytorch/vision/issues/8031")

    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_images_tv = []
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
        if "cmyk" in jpeg_path:
            continue
        decoded_image = read_image(jpeg_path)
        if decoded_image.shape[0] == 1:
            continue
        # Add a batch dim so that a 4D memory format can be requested, then drop it again.
        decoded_images_tv.append(decoded_image[None].contiguous(memory_format=memory_format)[0])

    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
    encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
    roundtripped_images = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device]

    for original, roundtripped in zip(decoded_images_tv, roundtripped_images):
        abs_mean_diff = (original.float() - roundtripped.float()).abs().mean().item()
        assert abs_mean_diff < 3

    # test multithreaded encoding
    # in the current version we prevent this by using a lock but we still want to test it
    num_workers = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
    encoded_images_threaded = [future.result() for future in futures]
    assert len(encoded_images_threaded) == num_workers

    reference_encoded_images = encoded_images_threaded[0]
    for encoded_images in encoded_images_threaded:
        assert len(decoded_images_tv_device) == len(encoded_images)
        for i, (encoded_image, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
            # make sure all the threads produce identical outputs
            # (torch.equal also fails cleanly when the encoded lengths differ)
            assert torch.equal(encoded_image, reference_encoded_images[i])

            # make sure the outputs are identical or close enough to baseline
            redecoded_image = decode_jpeg(encoded_image.cpu())
            assert redecoded_image.shape == decoded_image_tv.shape
            assert redecoded_image.dtype == decoded_image_tv.dtype
            assert (redecoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3
| 611 | + |
| 612 | + |
@needs_cuda
def test_single_encode_jpeg_cuda_errors():
    """Check the errors raised when a single invalid CUDA tensor is passed to ``encode_jpeg``.

    Each case is an (expected message, shape, dtype) triple; the tensor is
    created on ``"cuda"`` and the call must raise ``RuntimeError`` with a
    message matching the expected text.
    """
    invalid_inputs = (
        ("Input tensor dtype should be uint8", (3, 100, 100), torch.float32),
        ("The number of channels should be 3, got: 5", (5, 100, 100), torch.uint8),
        ("The number of channels should be 3, got: 1", (1, 100, 100), torch.uint8),
        ("Input data should be a 3-dimensional tensor", (1, 3, 100, 100), torch.uint8),
        ("Input data should be a 3-dimensional tensor", (100, 100), torch.uint8),
    )

    for expected_message, shape, dtype in invalid_inputs:
        with pytest.raises(RuntimeError, match=expected_message):
            encode_jpeg(torch.empty(shape, dtype=dtype, device="cuda"))
| 629 | + |
| 630 | + |
@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
    """Check the errors raised when an invalid list of tensors is passed to ``encode_jpeg``.

    Each case pairs an expected ``RuntimeError`` message with the
    (shape, dtype, device) description of the two tensors in the batch. The
    mixed ``cuda:0`` / ``cuda:1`` case only runs when at least two CUDA
    devices are available. An empty list must raise ``ValueError``.
    """
    same_device_message = "All input tensors must be on the same CUDA device when encoding with nvjpeg"
    rgb = (3, 100, 100)
    valid_cuda = (rgb, torch.uint8, "cuda")

    def build(spec):
        shape, dtype, device = spec
        return torch.empty(shape, dtype=dtype, device=device)

    cases = [
        ("Input tensor dtype should be uint8", (valid_cuda, (rgb, torch.float32, "cuda"))),
        ("The number of channels should be 3, got: 5", (valid_cuda, ((5, 100, 100), torch.uint8, "cuda"))),
        ("The number of channels should be 3, got: 1", (valid_cuda, ((1, 100, 100), torch.uint8, "cuda"))),
        ("Input data should be a 3-dimensional tensor", (valid_cuda, ((1, 3, 100, 100), torch.uint8, "cuda"))),
        ("Input data should be a 3-dimensional tensor", (valid_cuda, ((100, 100), torch.uint8, "cuda"))),
        ("Input tensor should be on CPU", ((rgb, torch.uint8, "cpu"), valid_cuda)),
        (same_device_message, (valid_cuda, (rgb, torch.uint8, "cpu"))),
    ]
    if torch.cuda.device_count() >= 2:
        cases.append((same_device_message, ((rgb, torch.uint8, "cuda:0"), (rgb, torch.uint8, "cuda:1"))))

    for expected_message, tensor_specs in cases:
        with pytest.raises(RuntimeError, match=expected_message):
            # The tensors are built inside the context manager, right before the call.
            encode_jpeg([build(spec) for spec in tensor_specs])

    with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
        encode_jpeg([])
| 704 | + |
| 705 | + |
511 | 706 | @pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
|
512 | 707 | @pytest.mark.parametrize(
|
513 | 708 | "img_path",
|
|
0 commit comments