@@ -603,6 +603,11 @@ def _worker_init_fn(worker_id):
603603 torch .utils .data .graph_settings .apply_sharding (datapipe , num_workers , worker_id )
604604
605605
606+ lambda_fn1 = lambda x : x # noqa: E731
607+ lambda_fn2 = lambda x : x % 2 # noqa: E731
608+ lambda_fn3 = lambda x : x >= 5 # noqa: E731
609+
610+
606611class TestFunctionalIterDataPipe (TestCase ):
607612
608613 def _serialization_test_helper (self , datapipe , use_dill ):
@@ -702,30 +707,58 @@ def test_serializable(self):
702707 def test_serializable_with_dill (self ):
703708 """Only for DataPipes that take in a function as argument"""
704709 input_dp = dp .iter .IterableWrapper (range (10 ))
705- unpicklable_datapipes : List [Tuple [Type [IterDataPipe ], Tuple , Dict [str , Any ]]] = [
706- (dp .iter .Collator , (lambda x : x ,), {}),
707- (dp .iter .Demultiplexer , (2 , lambda x : x % 2 ,), {}),
708- (dp .iter .Filter , (lambda x : x >= 5 ,), {}),
709- (dp .iter .Grouper , (lambda x : x >= 5 ,), {}),
710- (dp .iter .Mapper , (lambda x : x ,), {}),
710+
711+ datapipes_with_lambda_fn : List [Tuple [Type [IterDataPipe ], Tuple , Dict [str , Any ]]] = [
712+ (dp .iter .Collator , (lambda_fn1 ,), {}),
713+ (dp .iter .Demultiplexer , (2 , lambda_fn2 ,), {}),
714+ (dp .iter .Filter , (lambda_fn3 ,), {}),
715+ (dp .iter .Grouper , (lambda_fn3 ,), {}),
716+ (dp .iter .Mapper , (lambda_fn1 ,), {}),
711717 ]
718+
719+ def _local_fns ():
720+ def _fn1 (x ):
721+ return x
722+
723+ def _fn2 (x ):
724+ return x % 2
725+
726+ def _fn3 (x ):
727+ return x >= 5
728+
729+ return _fn1 , _fn2 , _fn3
730+
731+ fn1 , fn2 , fn3 = _local_fns ()
732+
733+ datapipes_with_local_fn : List [Tuple [Type [IterDataPipe ], Tuple , Dict [str , Any ]]] = [
734+ (dp .iter .Collator , (fn1 ,), {}),
735+ (dp .iter .Demultiplexer , (2 , fn2 ,), {}),
736+ (dp .iter .Filter , (fn3 ,), {}),
737+ (dp .iter .Grouper , (fn3 ,), {}),
738+ (dp .iter .Mapper , (fn1 ,), {}),
739+ ]
740+
712741 dp_compare_children = {dp .iter .Demultiplexer }
742+
713743 if HAS_DILL :
714- for dpipe , dp_args , dp_kwargs in unpicklable_datapipes :
744+ for dpipe , dp_args , dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn :
715745 if dpipe in dp_compare_children :
716746 dp1 , dp2 = dpipe (input_dp , * dp_args , ** dp_kwargs ) # type: ignore[call-arg]
717747 self ._serialization_test_for_dp_with_children (dp1 , dp2 , use_dill = True )
718748 else :
719749 datapipe = dpipe (input_dp , * dp_args , ** dp_kwargs ) # type: ignore[call-arg]
720750 self ._serialization_test_for_single_dp (datapipe , use_dill = True )
721751 else :
722- for dpipe , dp_args , dp_kwargs in unpicklable_datapipes :
723- with warnings .catch_warnings (record = True ) as wa :
724- datapipe = dpipe (input_dp , * dp_args , ** dp_kwargs ) # type: ignore[call-arg]
725- self .assertEqual (len (wa ), 1 )
726- self .assertRegex (str (wa [0 ].message ), r"^Lambda function is not supported for pickle" )
727- with self .assertRaises (AttributeError ):
728- p = pickle .dumps (datapipe )
752+ msgs = (
753+ r"^Lambda function is not supported by pickle" ,
754+ r"^Local function is not supported by pickle"
755+ )
756+ for dps , msg in zip ((datapipes_with_lambda_fn , datapipes_with_local_fn ), msgs ):
757+ for dpipe , dp_args , dp_kwargs in dps :
758+ with self .assertWarnsRegex (UserWarning , msg ):
759+ datapipe = dpipe (input_dp , * dp_args , ** dp_kwargs ) # type: ignore[call-arg]
760+ with self .assertRaises ((pickle .PicklingError , AttributeError )):
761+ pickle .dumps (datapipe )
729762
730763 def test_iterable_wrapper_datapipe (self ):
731764
@@ -1145,42 +1178,43 @@ def fn_n1(d0, d1):
11451178 def fn_nn (d0 , d1 ):
11461179 return - d0 , - d1 , d0 + d1
11471180
1148- def _helper (ref_fn , fn , input_col = None , output_col = None ):
1181+ def _helper (ref_fn , fn , input_col = None , output_col = None , error = None ):
11491182 for constr in (list , tuple ):
11501183 datapipe = dp .iter .IterableWrapper ([constr ((0 , 1 , 2 )), constr ((3 , 4 , 5 )), constr ((6 , 7 , 8 ))])
1151- res_dp = datapipe .map (fn , input_col , output_col )
1152- ref_dp = datapipe .map (ref_fn )
1153- self .assertEqual (list (res_dp ), list (ref_dp ))
1154- # Reset
1155- self .assertEqual (list (res_dp ), list (ref_dp ))
1184+ if ref_fn is None :
1185+ with self .assertRaises (error ):
1186+ res_dp = datapipe .map (fn , input_col , output_col )
1187+ list (res_dp )
1188+ else :
1189+ res_dp = datapipe .map (fn , input_col , output_col )
1190+ ref_dp = datapipe .map (ref_fn )
1191+ self .assertEqual (list (res_dp ), list (ref_dp ))
1192+ # Reset
1193+ self .assertEqual (list (res_dp ), list (ref_dp ))
11561194
11571195 # Replacing with one input column and default output column
11581196 _helper (lambda data : (data [0 ], - data [1 ], data [2 ]), fn_11 , 1 )
11591197 _helper (lambda data : (data [0 ], (- data [1 ], data [1 ]), data [2 ]), fn_1n , 1 )
11601198 # The index of input column is out of range
1161- with self .assertRaises (IndexError ):
1162- _helper (None , fn_1n , 3 )
1199+ _helper (None , fn_1n , 3 , error = IndexError )
11631200 # Unmatched input columns with fn arguments
1164- with self . assertRaises ( TypeError ):
1165- _helper ( None , fn_n1 , 1 )
1201+ _helper ( None , fn_n1 , 1 , error = TypeError )
1202+
11661203 # Replacing with multiple input columns and default output column (the left-most input column)
11671204 _helper (lambda data : (data [1 ], data [2 ] + data [0 ]), fn_n1 , [2 , 0 ])
11681205 _helper (lambda data : (data [0 ], (- data [2 ], - data [1 ], data [2 ] + data [1 ])), fn_nn , [2 , 1 ])
11691206
11701207 # output_col can only be specified when input_col is not None
1171- with self .assertRaises (ValueError ):
1172- _helper (None , fn_n1 , None , 1 )
1208+ _helper (None , fn_n1 , None , 1 , error = ValueError )
11731209 # output_col can only be single-element list or tuple
1174- with self .assertRaises (ValueError ):
1175- _helper (None , fn_n1 , None , [0 , 1 ])
1210+ _helper (None , fn_n1 , None , [0 , 1 ], error = ValueError )
11761211 # Single-element list as output_col
11771212 _helper (lambda data : (- data [1 ], data [1 ], data [2 ]), fn_11 , 1 , [0 ])
11781213 # Replacing with one input column and single specified output column
11791214 _helper (lambda data : (- data [1 ], data [1 ], data [2 ]), fn_11 , 1 , 0 )
11801215 _helper (lambda data : (data [0 ], data [1 ], (- data [1 ], data [1 ])), fn_1n , 1 , 2 )
11811216 # The index of output column is out of range
1182- with self .assertRaises (IndexError ):
1183- _helper (None , fn_1n , 1 , 3 )
1217+ _helper (None , fn_1n , 1 , 3 , error = IndexError )
11841218 _helper (lambda data : (data [0 ], data [0 ] + data [2 ], data [2 ]), fn_n1 , [0 , 2 ], 1 )
11851219 _helper (lambda data : ((- data [1 ], - data [2 ], data [1 ] + data [2 ]), data [1 ], data [2 ]), fn_nn , [1 , 2 ], 0 )
11861220
@@ -1213,38 +1247,39 @@ def _dict_update(data, newdata, remove_idx=None):
12131247 del _data [idx ]
12141248 return _data
12151249
1216- def _helper (ref_fn , fn , input_col = None , output_col = None ):
1250+ def _helper (ref_fn , fn , input_col = None , output_col = None , error = None ):
12171251 datapipe = dp .iter .IterableWrapper (
12181252 [{"x" : 0 , "y" : 1 , "z" : 2 },
12191253 {"x" : 3 , "y" : 4 , "z" : 5 },
12201254 {"x" : 6 , "y" : 7 , "z" : 8 }]
12211255 )
1222- res_dp = datapipe .map (fn , input_col , output_col )
1223- ref_dp = datapipe .map (ref_fn )
1224- self .assertEqual (list (res_dp ), list (ref_dp ))
1225- # Reset
1226- self .assertEqual (list (res_dp ), list (ref_dp ))
1256+ if ref_fn is None :
1257+ with self .assertRaises (error ):
1258+ res_dp = datapipe .map (fn , input_col , output_col )
1259+ list (res_dp )
1260+ else :
1261+ res_dp = datapipe .map (fn , input_col , output_col )
1262+ ref_dp = datapipe .map (ref_fn )
1263+ self .assertEqual (list (res_dp ), list (ref_dp ))
1264+ # Reset
1265+ self .assertEqual (list (res_dp ), list (ref_dp ))
12271266
12281267 # Replacing with one input column and default output column
12291268 _helper (lambda data : _dict_update (data , {"y" : - data ["y" ]}), fn_11 , "y" )
12301269 _helper (lambda data : _dict_update (data , {"y" : (- data ["y" ], data ["y" ])}), fn_1n , "y" )
12311270 # The key of input column is not in dict
1232- with self .assertRaises (KeyError ):
1233- _helper (None , fn_1n , "a" )
1271+ _helper (None , fn_1n , "a" , error = KeyError )
12341272 # Unmatched input columns with fn arguments
1235- with self .assertRaises (TypeError ):
1236- _helper (None , fn_n1 , "y" )
1273+ _helper (None , fn_n1 , "y" , error = TypeError )
12371274 # Replacing with multiple input columns and default output column (the left-most input column)
12381275 _helper (lambda data : _dict_update (data , {"z" : data ["x" ] + data ["z" ]}, ["x" ]), fn_n1 , ["z" , "x" ])
12391276 _helper (lambda data : _dict_update (
12401277 data , {"z" : (- data ["z" ], - data ["y" ], data ["y" ] + data ["z" ])}, ["y" ]), fn_nn , ["z" , "y" ])
12411278
12421279 # output_col can only be specified when input_col is not None
1243- with self .assertRaises (ValueError ):
1244- _helper (None , fn_n1 , None , "x" )
1280+ _helper (None , fn_n1 , None , "x" , error = ValueError )
12451281 # output_col can only be single-element list or tuple
1246- with self .assertRaises (ValueError ):
1247- _helper (None , fn_n1 , None , ["x" , "y" ])
1282+ _helper (None , fn_n1 , None , ["x" , "y" ], error = ValueError )
12481283 # Single-element list as output_col
12491284 _helper (lambda data : _dict_update (data , {"x" : - data ["y" ]}), fn_11 , "y" , ["x" ])
12501285 # Replacing with one input column and single specified output column
@@ -1617,24 +1652,41 @@ def test_serializable(self):
16171652 def test_serializable_with_dill (self ):
16181653 """Only for DataPipes that take in a function as argument"""
16191654 input_dp = dp .map .SequenceWrapper (range (10 ))
1620- unpicklable_datapipes : List [
1655+
1656+ datapipes_with_lambda_fn : List [
16211657 Tuple [Type [MapDataPipe ], Tuple , Dict [str , Any ]]
16221658 ] = [
1623- (dp .map .Mapper , (lambda x : x ,), {}),
1659+ (dp .map .Mapper , (lambda_fn1 ,), {}),
16241660 ]
1661+
1662+ def _local_fns ():
1663+ def _fn1 (x ):
1664+ return x
1665+
1666+ return _fn1
1667+
1668+ fn1 = _local_fns ()
1669+
1670+ datapipes_with_local_fn : List [
1671+ Tuple [Type [MapDataPipe ], Tuple , Dict [str , Any ]]
1672+ ] = [
1673+ (dp .map .Mapper , (fn1 ,), {}),
1674+ ]
1675+
16251676 if HAS_DILL :
1626- for dpipe , dp_args , dp_kwargs in unpicklable_datapipes :
1677+ for dpipe , dp_args , dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn :
16271678 _ = dill .dumps (dpipe (input_dp , * dp_args , ** dp_kwargs )) # type: ignore[call-arg]
16281679 else :
1629- for dpipe , dp_args , dp_kwargs in unpicklable_datapipes :
1630- with warnings .catch_warnings (record = True ) as wa :
1631- datapipe = dpipe (input_dp , * dp_args , ** dp_kwargs ) # type: ignore[call-arg]
1632- self .assertEqual (len (wa ), 1 )
1633- self .assertRegex (
1634- str (wa [0 ].message ), r"^Lambda function is not supported for pickle"
1635- )
1636- with self .assertRaises (AttributeError ):
1637- p = pickle .dumps (datapipe )
1680+ msgs = (
1681+ r"^Lambda function is not supported by pickle" ,
1682+ r"^Local function is not supported by pickle"
1683+ )
1684+ for dps , msg in zip ((datapipes_with_lambda_fn , datapipes_with_local_fn ), msgs ):
1685+ for dpipe , dp_args , dp_kwargs in dps :
1686+ with self .assertWarnsRegex (UserWarning , msg ):
1687+ datapipe = dpipe (input_dp , * dp_args , ** dp_kwargs ) # type: ignore[call-arg]
1688+ with self .assertRaises ((pickle .PicklingError , AttributeError )):
1689+ pickle .dumps (datapipe )
16381690
16391691 def test_sequence_wrapper_datapipe (self ):
16401692 seq = list (range (10 ))
0 commit comments