Skip to content

feat: expose loader argument in Country211 and EuroSAT. #8922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ class ImageDatasetTestCase(DatasetTestCase):
"""

FEATURE_TYPES = (PIL.Image.Image, int)
SUPPORT_TV_IMAGE_DECODE: bool = False

@contextlib.contextmanager
def create_dataset(
Expand All @@ -632,22 +633,34 @@ def create_dataset(
# This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an
# image, but never use the underlying data. During normal operation it is reasonable to assume that the
# user wants to work with the image he just opened rather than deleting the underlying file.
with self._force_load_images():
with self._force_load_images(loader=(config or {}).get("loader", None)):
yield dataset, info

@contextlib.contextmanager
def _force_load_images(self):
open = PIL.Image.open
def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None):
open = loader or PIL.Image.open

def new(fp, *args, **kwargs):
image = open(fp, *args, **kwargs)
if isinstance(fp, (str, pathlib.Path)):
if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image):
image.load()
return image

with unittest.mock.patch("PIL.Image.open", new=new):
with unittest.mock.patch(open.__module__ + "." + open.__qualname__, new=new):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn, I hope we never have to change / fix this ever lol.

Thank you so much for pushing through the test @GdoongMathew

yield

def test_tv_decode_image_support(self):
if not self.SUPPORT_TV_IMAGE_DECODE:
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")

with self.create_dataset(
config=dict(
loader=torchvision.io.decode_image,
)
) as (dataset, _):
image = dataset[0][0]
assert isinstance(image, torch.Tensor)


class VideoDatasetTestCase(DatasetTestCase):
"""Abstract base class for video dataset testcases.
Expand Down
5 changes: 5 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

Expand Down Expand Up @@ -2308,6 +2310,7 @@ def inject_fake_data(self, tmpdir, config):
class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.EuroSAT
FEATURE_TYPES = (PIL.Image.Image, int)
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
data_folder = os.path.join(tmpdir, "eurosat", "2750")
Expand Down Expand Up @@ -2749,6 +2752,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):

ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
split_folder.mkdir(parents=True, exist_ok=True)
Expand Down
15 changes: 12 additions & 3 deletions torchvision/datasets/country211.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union

from .folder import ImageFolder
from .folder import default_loader, ImageFolder
from .utils import download_and_extract_archive, verify_str_arg


Expand All @@ -21,6 +21,9 @@ class Country211(ImageFolder):
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and puts it into
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
loader (callable, optional): A function to load an image given its path.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @GdoongMathew ,

Here and below, we should specify that the default is to use PIL, but that we encourage users to try to use decode_image()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added. I also updated the one in ImageNet dataset as well.

By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
Expand All @@ -33,6 +36,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[str], Any] = default_loader,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))

Expand All @@ -46,7 +50,12 @@ def __init__(
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
super().__init__(
str(self._base_folder / self._split),
transform=transform,
target_transform=target_transform,
loader=loader,
)
self.root = str(root)

def _check_exists(self) -> bool:
Expand Down
15 changes: 12 additions & 3 deletions torchvision/datasets/eurosat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union

from .folder import ImageFolder
from .folder import default_loader, ImageFolder
from .utils import download_and_extract_archive


Expand All @@ -21,6 +21,9 @@ class EuroSAT(ImageFolder):
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

def __init__(
Expand All @@ -29,6 +32,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[str], Any] = default_loader,
) -> None:
self.root = os.path.expanduser(root)
self._base_folder = os.path.join(self.root, "eurosat")
Expand All @@ -40,7 +44,12 @@ def __init__(
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
super().__init__(
self._data_folder,
transform=transform,
target_transform=target_transform,
loader=loader,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you feel about writing a basic test for this? Hopefully this can fit in <10 lines of code and we can re-use the same test across most datasets?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. Hopefully, for the rest of the test classes, one could update their SUPPORT_TV_IMAGE_DECODE attribute whenever they start to support loader argument.

self.root = os.path.expanduser(root)

def __len__(self) -> int:
Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class ImageNet(ImageFolder):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.

Attributes:
classes (list): List of the class name tuples.
Expand Down