1
1
import os
2
2
import os .path
3
3
from typing import Any , Callable , cast , Dict , List , Optional , Tuple
4
+ from typing import Union
4
5
5
6
from PIL import Image
6
7
7
8
from .vision import VisionDataset
8
9
9
10
10
- def has_file_allowed_extension (filename : str , extensions : Tuple [str , ...]) -> bool :
11
+ def has_file_allowed_extension (filename : str , extensions : Union [ str , Tuple [str , ...] ]) -> bool :
11
12
"""Checks if a file is an allowed extension.
12
13
13
14
Args:
@@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
17
18
Returns:
18
19
bool: True if the filename ends with one of given extensions
19
20
"""
20
- return filename .lower ().endswith (extensions )
21
+ return filename .lower ().endswith (extensions if isinstance ( extensions , str ) else tuple ( extensions ) )
21
22
22
23
23
24
def is_image_file (filename : str ) -> bool :
@@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
48
49
def make_dataset (
49
50
directory : str ,
50
51
class_to_idx : Optional [Dict [str , int ]] = None ,
51
- extensions : Optional [Tuple [str , ...]] = None ,
52
+ extensions : Optional [Union [ str , Tuple [str , ...] ]] = None ,
52
53
is_valid_file : Optional [Callable [[str ], bool ]] = None ,
53
54
) -> List [Tuple [str , int ]]:
54
55
"""Generates a list of samples of a form (path_to_sample, class).
@@ -73,7 +74,7 @@ def make_dataset(
73
74
if extensions is not None :
74
75
75
76
def is_valid_file (x : str ) -> bool :
76
- return has_file_allowed_extension (x , cast ( Tuple [ str , ...], extensions ))
77
+ return has_file_allowed_extension (x , extensions ) # type: ignore[arg-type]
77
78
78
79
is_valid_file = cast (Callable [[str ], bool ], is_valid_file )
79
80
@@ -98,7 +99,7 @@ def is_valid_file(x: str) -> bool:
98
99
if empty_classes :
99
100
msg = f"Found no valid file for the classes { ', ' .join (sorted (empty_classes ))} . "
100
101
if extensions is not None :
101
- msg += f"Supported extensions are: { ', ' .join (extensions )} "
102
+ msg += f"Supported extensions are: { extensions if isinstance ( extensions , str ) else ', ' .join (extensions )} "
102
103
raise FileNotFoundError (msg )
103
104
104
105
return instances
0 commit comments