11import os
22import os .path
33from typing import Any , Callable , cast , Dict , List , Optional , Tuple
4+ from typing import Union
45
56from PIL import Image
67
78from .vision import VisionDataset
89
910
10- def has_file_allowed_extension (filename : str , extensions : Tuple [str , ...]) -> bool :
11+ def has_file_allowed_extension (filename : str , extensions : Union [ str , Tuple [str , ...] ]) -> bool :
1112 """Checks if a file is an allowed extension.
1213
1314 Args:
@@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
1718 Returns:
1819 bool: True if the filename ends with one of given extensions
1920 """
20- return filename .lower ().endswith (extensions )
21+ return filename .lower ().endswith (extensions if isinstance ( extensions , str ) else tuple ( extensions ) )
2122
2223
2324def is_image_file (filename : str ) -> bool :
@@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
4849def make_dataset (
4950 directory : str ,
5051 class_to_idx : Optional [Dict [str , int ]] = None ,
51- extensions : Optional [Tuple [str , ...]] = None ,
52+ extensions : Optional [Union [ str , Tuple [str , ...] ]] = None ,
5253 is_valid_file : Optional [Callable [[str ], bool ]] = None ,
5354) -> List [Tuple [str , int ]]:
5455 """Generates a list of samples of a form (path_to_sample, class).
@@ -73,7 +74,7 @@ def make_dataset(
7374 if extensions is not None :
7475
7576 def is_valid_file (x : str ) -> bool :
76- return has_file_allowed_extension (x , cast ( Tuple [ str , ...], extensions ))
77+ return has_file_allowed_extension (x , extensions ) # type: ignore[arg-type]
7778
7879 is_valid_file = cast (Callable [[str ], bool ], is_valid_file )
7980
@@ -98,7 +99,7 @@ def is_valid_file(x: str) -> bool:
9899 if empty_classes :
99100 msg = f"Found no valid file for the classes { ', ' .join (sorted (empty_classes ))} . "
100101 if extensions is not None :
101- msg += f"Supported extensions are: { ', ' .join (extensions )} "
102+ msg += f"Supported extensions are: { extensions if isinstance ( extensions , str ) else ', ' .join (extensions )} "
102103 raise FileNotFoundError (msg )
103104
104105 return instances
0 commit comments