5 changes: 5 additions & 0 deletions docs/source/en/model_doc/swin2sr.md
@@ -50,6 +50,11 @@ A demo Space for image super-resolution with SwinSR can be found [here](https://
[[autodoc]] Swin2SRImageProcessor
- preprocess

## Swin2SRImageProcessorFast

[[autodoc]] Swin2SRImageProcessorFast
- preprocess

## Swin2SRConfig

[[autodoc]] Swin2SRConfig
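As an aside, not part of the diff: a minimal sketch of the newly documented fast processor with its default settings (requires torch and torchvision; the input shape is arbitrary, and the expected output follows from the pad-to-multiple-of-8 logic shown later in this PR):

import torch
from transformers import Swin2SRImageProcessorFast

processor = Swin2SRImageProcessorFast()  # defaults: do_rescale=True, do_pad=True, pad_size=8
image = torch.randint(0, 256, (3, 61, 45), dtype=torch.uint8)
inputs = processor(image, return_tensors="pt")
print(inputs.pixel_values.shape)  # expected: torch.Size([1, 3, 64, 48]), padded up to multiples of 8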
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
@@ -150,7 +150,7 @@
("superglue", ("SuperGlueImageProcessor",)),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin2sr", ("Swin2SRImageProcessor",)),
("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
("table-transformer", ("DetrImageProcessor",)),
("timesformer", ("VideoMAEImageProcessor",)),
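With this mapping entry, AutoImageProcessor can now resolve the fast class for Swin2SR. A minimal sketch, not part of the diff; the checkpoint name is only an illustrative assumption, any Swin2SR checkpoint on the Hub works the same way:

from transformers import AutoImageProcessor

# use_fast=True picks Swin2SRImageProcessorFast via the mapping above,
# use_fast=False keeps the original Swin2SRImageProcessor
processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64", use_fast=True)
print(type(processor).__name__)  # expected: Swin2SRImageProcessorFast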
1 change: 1 addition & 0 deletions src/transformers/models/swin2sr/__init__.py
@@ -20,6 +20,7 @@
if TYPE_CHECKING:
    from .configuration_swin2sr import *
    from .image_processing_swin2sr import *
    from .image_processing_swin2sr_fast import *
    from .modeling_swin2sr import *
else:
    import sys
138 changes: 138 additions & 0 deletions src/transformers/models/swin2sr/image_processing_swin2sr_fast.py
@@ -0,0 +1,138 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Swin2SR."""

from typing import List, Optional, Union

from ...image_processing_utils import (
    BatchFeature,
    ChannelDimension,
    get_image_size,
)
from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import ImageInput
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
)


if is_torch_available():
    import torch

if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


class Swin2SRFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    do_pad: Optional[bool]
    pad_size: Optional[int]


@add_start_docstrings(
    "Constructs a fast Swin2SR image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
    do_pad (`bool`, *optional*, defaults to `True`):
        Whether to pad the image to make the height and width divisible by `window_size`.
    pad_size (`int`, *optional*, defaults to `8`):
        The size of the sliding window for the local attention.
    """,
)
class Swin2SRImageProcessorFast(BaseImageProcessorFast):
    do_rescale = True
    rescale_factor = 1 / 255
    do_pad = True
    pad_size = 8
    valid_kwargs = Swin2SRFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to make the height and width divisible by `window_size`.
        pad_size (`int`, *optional*, defaults to `8`):
            The size of the sliding window for the local attention.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def pad(self, images: "torch.Tensor", size: int) -> "torch.Tensor":
        """
        Pad an image to make the height and width divisible by `size`.

        Args:
            images (`torch.Tensor`):
                Images to pad.
            size (`int`):
                The size to make the height and width divisible by.

        Returns:
            `torch.Tensor`: The padded images.
        """
        height, width = get_image_size(images, ChannelDimension.FIRST)
        pad_height = (height // size + 1) * size - height
        pad_width = (width // size + 1) * size - width

        return F.pad(
            images,
            (0, 0, pad_width, pad_height),
            padding_mode="symmetric",
        )

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_rescale: bool,
        rescale_factor: float,
        do_pad: bool,
        pad_size: int,
        return_tensors: Optional[Union[str, TensorType]],
        interpolation: Optional["F.InterpolationMode"],
        **kwargs,
    ) -> BatchFeature:
        grouped_images, grouped_images_index = group_images_by_shape(images)
        processed_image_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
                stacked_images = self.rescale(stacked_images, scale=rescale_factor)
            if do_pad:
                stacked_images = self.pad(stacked_images, size=pad_size)
            processed_image_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_image_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


__all__ = ["Swin2SRImageProcessorFast"]
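A standalone sketch of the padding arithmetic used by pad() above, not part of the diff (the helper name is made up for illustration, and torchvision v2 is assumed). Note that the formula always pads up to the next multiple of size, even when a dimension is already divisible, matching the slow processor's behavior:

import torch
from torchvision.transforms.v2 import functional as F

def pad_to_multiple(images: torch.Tensor, size: int = 8) -> torch.Tensor:
    # pad the bottom and right edges up to the next multiple of `size`,
    # mirroring border pixels with symmetric padding
    height, width = images.shape[-2], images.shape[-1]
    pad_height = (height // size + 1) * size - height
    pad_width = (width // size + 1) * size - width
    return F.pad(images, (0, 0, pad_width, pad_height), padding_mode="symmetric")

x = torch.rand(2, 3, 61, 45)
print(pad_to_multiple(x).shape)  # expected: torch.Size([2, 3, 64, 48])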
32 changes: 26 additions & 6 deletions tests/models/swin2sr/test_image_processing_swin2sr.py
@@ -18,7 +18,7 @@
import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs

@@ -30,6 +30,9 @@
from PIL import Image

from transformers import Swin2SRImageProcessor

if is_torchvision_available():
    from transformers import Swin2SRImageProcessorFast
from transformers.image_transforms import get_image_size


@@ -97,6 +100,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
@require_vision
class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = Swin2SRImageProcessor if is_vision_available() else None
    fast_image_processing_class = Swin2SRImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
@@ -107,11 +111,12 @@ def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processor = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processor, "do_rescale"))
        self.assertTrue(hasattr(image_processor, "rescale_factor"))
        self.assertTrue(hasattr(image_processor, "do_pad"))
        self.assertTrue(hasattr(image_processor, "pad_size"))
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
Comment on lines +114 to +115

Member:
This should be done for all tests

thisisiron (Contributor Author), Apr 2, 2025:
I ran the following command:

RUN_SLOW=1 python -m pytest tests/models/swin2sr/test_image_processing_swin2sr.py

The log below shows the result of executing the above command.

configfile: pyproject.toml
plugins: timeout-2.3.1, hypothesis-6.124.2, anyio-4.8.0, asyncio-0.23.8, rich-0.2.0, xdist-3.6.1
asyncio: mode=strict
collected 18 items                                                                                                                                                                                                                         

tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_call_numpy PASSED                                                                                                                            [  5%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_call_numpy_4_channels PASSED                                                                                                                 [ 11%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_call_pil PASSED                                                                                                                              [ 16%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_call_pytorch PASSED                                                                                                                          [ 22%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_can_compile_fast_image_processor PASSED                                                                                                      [ 27%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_cast_dtype_device PASSED                                                                                                                     [ 33%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_fast_is_faster_than_slow PASSED                                                                                                              [ 38%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_image_processor_from_and_save_pretrained PASSED                                                                                              [ 44%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_image_processor_preprocess_arguments PASSED                                                                                                  [ 50%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_image_processor_properties PASSED                                                                                                            [ 55%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_image_processor_save_load_with_autoimageprocessor PASSED                                                                                     [ 61%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_image_processor_to_json_file PASSED                                                                                                          [ 66%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_image_processor_to_json_string PASSED                                                                                                        [ 72%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_init_without_params PASSED                                                                                                                   [ 77%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_save_load_fast_slow PASSED                                                                                                                   [ 83%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_save_load_fast_slow_auto PASSED                                                                                                              [ 88%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_slow_fast_equivalence PASSED                                                                                                                 [ 94%]
tests/models/swin2sr/test_image_processing_swin2sr.py::Swin2SRImageProcessingTest::test_slow_fast_equivalence_batched PASSED                                                                                                         [100%]

============================================================================================================ 18 passed in 6.78s ============================================================================================================

            self.assertTrue(hasattr(image_processing, "do_rescale"))
            self.assertTrue(hasattr(image_processing, "rescale_factor"))
            self.assertTrue(hasattr(image_processing, "do_pad"))
            self.assertTrue(hasattr(image_processing, "pad_size"))

    def calculate_expected_size(self, image):
        old_height, old_width = get_image_size(image)
@@ -181,3 +186,18 @@ def test_call_pytorch(self):
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

    @unittest.skip(reason="No speed gain on CPU due to minimal processing.")
    def test_fast_is_faster_than_slow(self):
        pass

    def test_slow_fast_equivalence_batched(self):
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)

        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

        encoded_slow = image_processor_slow(image_inputs, return_tensors="pt").pixel_values
        encoded_fast = image_processor_fast(image_inputs, return_tensors="pt").pixel_values

        self.assertTrue(torch.allclose(encoded_slow, encoded_fast, atol=1e-1))
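For completeness, a sketch of the same slow/fast equivalence check run ad hoc outside the test harness (not part of the diff; the random input is arbitrary and the tolerance mirrors the test above):

import torch
from transformers import Swin2SRImageProcessor, Swin2SRImageProcessorFast

image = torch.randint(0, 256, (3, 61, 45), dtype=torch.uint8)
encoded_slow = Swin2SRImageProcessor()(image, return_tensors="pt").pixel_values
encoded_fast = Swin2SRImageProcessorFast()(image, return_tensors="pt").pixel_values
print(torch.allclose(encoded_slow, encoded_fast, atol=1e-1))  # expected: True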