Skip to content
16 changes: 16 additions & 0 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
model_input_names = ["pixel_values"]
valid_init_kwargs = DefaultFastImageProcessorInitKwargs
valid_preprocess_kwargs = DefaultFastImageProcessorPreprocessKwargs
# Child classes should try to support the base processing methods as much as possible.
# If they can't, the corresponding unused kwargs should be added to this list.
unused_kwargs = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unused_kwargs = []
unused_kwargs = None

list default are the worst!


def __init__(
self,
**kwargs: Unpack[DefaultFastImageProcessorInitKwargs],
) -> None:
super().__init__(**kwargs)
kwargs = self.filter_out_unused_kwargs(kwargs)
size = kwargs.pop("size", self.size)
self.size = (
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
Expand Down Expand Up @@ -429,6 +433,16 @@ def convert_to_rgb(
"""
return convert_to_rgb(image)

def filter_out_unused_kwargs(self, kwargs: dict):
"""
Filter out the unused kwargs from the kwargs dictionary.
"""
for kwarg_name in self.unused_kwargs:
if kwarg_name in kwargs:
logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
kwargs.pop(kwarg_name)
return kwargs

def _prepare_images_structure(
self,
images: ImageInput,
Expand Down Expand Up @@ -555,6 +569,7 @@ def preprocess(
images: ImageInput,
**kwargs: Unpack[DefaultFastImageProcessorPreprocessKwargs],
) -> BatchFeature:
kwargs = self.filter_out_unused_kwargs(kwargs)
validate_kwargs(
captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys()
)
Expand Down Expand Up @@ -627,6 +642,7 @@ def _preprocess(
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
Expand Down
237 changes: 85 additions & 152 deletions src/transformers/models/siglip2/image_processing_siglip2_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,41 @@
import torch

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
SizeDict,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
TensorType,
)
from ...processing_utils import Unpack
from ...utils import (
filter_out_non_signature_kwargs,
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
logging,
)


if is_torch_available():
import torch

if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)


@lru_cache(maxsize=256)
# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches
Expand Down Expand Up @@ -118,164 +136,79 @@ def pad_along_first_dim(
return tensor, mask


class Siglip2ImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast SigLIP2 image processor.
class Siglip2FastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
patch_size: Optional[int]
max_num_patches: Optional[int]

Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
Can be overridden by `do_resize` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.

class Siglip2FastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
patch_size: Optional[int]
max_num_patches: Optional[int]


@add_start_docstrings(
r"Constructs a fast Siglip2 image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch the image will be split to.
max_num_patches (`int`, *optional*, defaults to 256):
The image will be resized to have at most this number of patches,
and then padded in "patch" dimension to match this number exactly.
"""

def __init__(
self,
do_resize: bool = True,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we keep do_convert_rgb?

patch_size: int = 16,
max_num_patches: int = 256,
**kwargs,
):
""",
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
do_resize = True
do_rescale = True
do_normalize = True
patch_size = 16
max_num_patches = 256
valid_init_kwargs = Siglip2FastImageProcessorInitKwargs
valid_preprocess_kwargs = Siglip2FastImageProcessorPreprocessKwargs
unused_kwargs = ["size", "do_center_crop", "crop_size"]

def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorInitKwargs]):
super().__init__(**kwargs)

image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]

self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.patch_size = patch_size
self.max_num_patches = max_num_patches

@filter_out_non_signature_kwargs()
@lru_cache(maxsize=10)
def _prepare_process_arguments(self, **kwargs) -> tuple:
# Remove do_resize from kwargs to not raise an error as size is None
kwargs.pop("do_resize", None)
return super()._prepare_process_arguments(**kwargs)

@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The size (resolution) of each patch the image will be split to.
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
The image will be resized to have at most this number of patches,
and then padded in "patch" dimension to match this number exactly.
""",
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
resample: Optional[PILImageResampling] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: Optional[bool] = None,
patch_size: Optional[int] = None,
max_num_patches: Optional[int] = None,
device: Union["torch.device", str] = "cpu",
self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorPreprocessKwargs]
) -> BatchFeature:
"""
Preprocess an image or batch of images.

Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
Patch size for processing, same as the patch size used in the model.
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
Maximum number of patches per image, the image will be resized to have at most this number of patches.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
patch_size = patch_size if patch_size is not None else self.patch_size
max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches

image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std

image_mean, image_std, interpolation = self._prepare_process_arguments(
do_normalize=do_normalize,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
resample=resample,
)

images = self._prepare_input_images(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
return super().preprocess(images, **kwargs)

def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
patch_size: int,
max_num_patches: int,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
pixel_masks = []
pixel_values = []
spatial_shapes = []
Expand Down