Skip to content

Commit 373109c

Browse files
ejguan authored and facebook-github-bot committed
Raise warning for unpicklable local function (#80232)
Summary:
Pull Request resolved: #80232

X-link: meta-pytorch/data#547

Fixes meta-pytorch/data#538

- Improve the validation function to raise a warning about an unpicklable function when either a lambda or a local function is provided to a DataPipe.
- The inner function of a functools.partial object is extracted for validation as well.
- Mimic the behavior of the pickle module for a local lambda function: pickle raises its error about the local object rather than about the lambda, so we warn about a local function, not a lambda function.

```py
>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> fn()
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```

This Diff also fixes the Error introduced by #79344

Test Plan:
```
buck test //caffe2/test:datapipe
buck test //pytorch/data/test:tests
```

Tested in OSS
```
# PT
pytest test/test_datapipe.py -v
# TD
pytest test/test_iterdatapipe.py -v
pytest test/test_mapdatapipe.py -v
pytest test/test_serialization.py -v
# TV
pytest test/test_prototype_builtin_datasets.py -v
```

Reviewed By: NivekT

Differential Revision: D37417556

fbshipit-source-id: 6fae4059285b8c742feda739cc5fe590b2e20c5e
1 parent 590d3e5 commit 373109c

File tree

7 files changed

+163
-71
lines changed

7 files changed

+163
-71
lines changed

test/test_datapipe.py

Lines changed: 108 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -603,6 +603,11 @@ def _worker_init_fn(worker_id):
603603
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
604604

605605

606+
lambda_fn1 = lambda x: x # noqa: E731
607+
lambda_fn2 = lambda x: x % 2 # noqa: E731
608+
lambda_fn3 = lambda x: x >= 5 # noqa: E731
609+
610+
606611
class TestFunctionalIterDataPipe(TestCase):
607612

608613
def _serialization_test_helper(self, datapipe, use_dill):
@@ -702,30 +707,58 @@ def test_serializable(self):
702707
def test_serializable_with_dill(self):
703708
"""Only for DataPipes that take in a function as argument"""
704709
input_dp = dp.iter.IterableWrapper(range(10))
705-
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
706-
(dp.iter.Collator, (lambda x: x,), {}),
707-
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
708-
(dp.iter.Filter, (lambda x: x >= 5,), {}),
709-
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
710-
(dp.iter.Mapper, (lambda x: x,), {}),
710+
711+
datapipes_with_lambda_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
712+
(dp.iter.Collator, (lambda_fn1,), {}),
713+
(dp.iter.Demultiplexer, (2, lambda_fn2,), {}),
714+
(dp.iter.Filter, (lambda_fn3,), {}),
715+
(dp.iter.Grouper, (lambda_fn3,), {}),
716+
(dp.iter.Mapper, (lambda_fn1,), {}),
711717
]
718+
719+
def _local_fns():
720+
def _fn1(x):
721+
return x
722+
723+
def _fn2(x):
724+
return x % 2
725+
726+
def _fn3(x):
727+
return x >= 5
728+
729+
return _fn1, _fn2, _fn3
730+
731+
fn1, fn2, fn3 = _local_fns()
732+
733+
datapipes_with_local_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
734+
(dp.iter.Collator, (fn1,), {}),
735+
(dp.iter.Demultiplexer, (2, fn2,), {}),
736+
(dp.iter.Filter, (fn3,), {}),
737+
(dp.iter.Grouper, (fn3,), {}),
738+
(dp.iter.Mapper, (fn1,), {}),
739+
]
740+
712741
dp_compare_children = {dp.iter.Demultiplexer}
742+
713743
if HAS_DILL:
714-
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
744+
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
715745
if dpipe in dp_compare_children:
716746
dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
717747
self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True)
718748
else:
719749
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
720750
self._serialization_test_for_single_dp(datapipe, use_dill=True)
721751
else:
722-
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
723-
with warnings.catch_warnings(record=True) as wa:
724-
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
725-
self.assertEqual(len(wa), 1)
726-
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
727-
with self.assertRaises(AttributeError):
728-
p = pickle.dumps(datapipe)
752+
msgs = (
753+
r"^Lambda function is not supported by pickle",
754+
r"^Local function is not supported by pickle"
755+
)
756+
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
757+
for dpipe, dp_args, dp_kwargs in dps:
758+
with self.assertWarnsRegex(UserWarning, msg):
759+
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
760+
with self.assertRaises((pickle.PicklingError, AttributeError)):
761+
pickle.dumps(datapipe)
729762

730763
def test_iterable_wrapper_datapipe(self):
731764

@@ -1145,42 +1178,43 @@ def fn_n1(d0, d1):
11451178
def fn_nn(d0, d1):
11461179
return -d0, -d1, d0 + d1
11471180

1148-
def _helper(ref_fn, fn, input_col=None, output_col=None):
1181+
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
11491182
for constr in (list, tuple):
11501183
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
1151-
res_dp = datapipe.map(fn, input_col, output_col)
1152-
ref_dp = datapipe.map(ref_fn)
1153-
self.assertEqual(list(res_dp), list(ref_dp))
1154-
# Reset
1155-
self.assertEqual(list(res_dp), list(ref_dp))
1184+
if ref_fn is None:
1185+
with self.assertRaises(error):
1186+
res_dp = datapipe.map(fn, input_col, output_col)
1187+
list(res_dp)
1188+
else:
1189+
res_dp = datapipe.map(fn, input_col, output_col)
1190+
ref_dp = datapipe.map(ref_fn)
1191+
self.assertEqual(list(res_dp), list(ref_dp))
1192+
# Reset
1193+
self.assertEqual(list(res_dp), list(ref_dp))
11561194

11571195
# Replacing with one input column and default output column
11581196
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
11591197
_helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
11601198
# The index of input column is out of range
1161-
with self.assertRaises(IndexError):
1162-
_helper(None, fn_1n, 3)
1199+
_helper(None, fn_1n, 3, error=IndexError)
11631200
# Unmatched input columns with fn arguments
1164-
with self.assertRaises(TypeError):
1165-
_helper(None, fn_n1, 1)
1201+
_helper(None, fn_n1, 1, error=TypeError)
1202+
11661203
# Replacing with multiple input columns and default output column (the left-most input column)
11671204
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
11681205
_helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])
11691206

11701207
# output_col can only be specified when input_col is not None
1171-
with self.assertRaises(ValueError):
1172-
_helper(None, fn_n1, None, 1)
1208+
_helper(None, fn_n1, None, 1, error=ValueError)
11731209
# output_col can only be single-element list or tuple
1174-
with self.assertRaises(ValueError):
1175-
_helper(None, fn_n1, None, [0, 1])
1210+
_helper(None, fn_n1, None, [0, 1], error=ValueError)
11761211
# Single-element list as output_col
11771212
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
11781213
# Replacing with one input column and single specified output column
11791214
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
11801215
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
11811216
# The index of output column is out of range
1182-
with self.assertRaises(IndexError):
1183-
_helper(None, fn_1n, 1, 3)
1217+
_helper(None, fn_1n, 1, 3, error=IndexError)
11841218
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
11851219
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
11861220

@@ -1213,38 +1247,39 @@ def _dict_update(data, newdata, remove_idx=None):
12131247
del _data[idx]
12141248
return _data
12151249

1216-
def _helper(ref_fn, fn, input_col=None, output_col=None):
1250+
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
12171251
datapipe = dp.iter.IterableWrapper(
12181252
[{"x": 0, "y": 1, "z": 2},
12191253
{"x": 3, "y": 4, "z": 5},
12201254
{"x": 6, "y": 7, "z": 8}]
12211255
)
1222-
res_dp = datapipe.map(fn, input_col, output_col)
1223-
ref_dp = datapipe.map(ref_fn)
1224-
self.assertEqual(list(res_dp), list(ref_dp))
1225-
# Reset
1226-
self.assertEqual(list(res_dp), list(ref_dp))
1256+
if ref_fn is None:
1257+
with self.assertRaises(error):
1258+
res_dp = datapipe.map(fn, input_col, output_col)
1259+
list(res_dp)
1260+
else:
1261+
res_dp = datapipe.map(fn, input_col, output_col)
1262+
ref_dp = datapipe.map(ref_fn)
1263+
self.assertEqual(list(res_dp), list(ref_dp))
1264+
# Reset
1265+
self.assertEqual(list(res_dp), list(ref_dp))
12271266

12281267
# Replacing with one input column and default output column
12291268
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
12301269
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
12311270
# The key of input column is not in dict
1232-
with self.assertRaises(KeyError):
1233-
_helper(None, fn_1n, "a")
1271+
_helper(None, fn_1n, "a", error=KeyError)
12341272
# Unmatched input columns with fn arguments
1235-
with self.assertRaises(TypeError):
1236-
_helper(None, fn_n1, "y")
1273+
_helper(None, fn_n1, "y", error=TypeError)
12371274
# Replacing with multiple input columns and default output column (the left-most input column)
12381275
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
12391276
_helper(lambda data: _dict_update(
12401277
data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), fn_nn, ["z", "y"])
12411278

12421279
# output_col can only be specified when input_col is not None
1243-
with self.assertRaises(ValueError):
1244-
_helper(None, fn_n1, None, "x")
1280+
_helper(None, fn_n1, None, "x", error=ValueError)
12451281
# output_col can only be single-element list or tuple
1246-
with self.assertRaises(ValueError):
1247-
_helper(None, fn_n1, None, ["x", "y"])
1282+
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
12481283
# Single-element list as output_col
12491284
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
12501285
# Replacing with one input column and single specified output column
@@ -1617,24 +1652,41 @@ def test_serializable(self):
16171652
def test_serializable_with_dill(self):
16181653
"""Only for DataPipes that take in a function as argument"""
16191654
input_dp = dp.map.SequenceWrapper(range(10))
1620-
unpicklable_datapipes: List[
1655+
1656+
datapipes_with_lambda_fn: List[
16211657
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
16221658
] = [
1623-
(dp.map.Mapper, (lambda x: x,), {}),
1659+
(dp.map.Mapper, (lambda_fn1,), {}),
16241660
]
1661+
1662+
def _local_fns():
1663+
def _fn1(x):
1664+
return x
1665+
1666+
return _fn1
1667+
1668+
fn1 = _local_fns()
1669+
1670+
datapipes_with_local_fn: List[
1671+
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
1672+
] = [
1673+
(dp.map.Mapper, (fn1,), {}),
1674+
]
1675+
16251676
if HAS_DILL:
1626-
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
1677+
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
16271678
_ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
16281679
else:
1629-
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
1630-
with warnings.catch_warnings(record=True) as wa:
1631-
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
1632-
self.assertEqual(len(wa), 1)
1633-
self.assertRegex(
1634-
str(wa[0].message), r"^Lambda function is not supported for pickle"
1635-
)
1636-
with self.assertRaises(AttributeError):
1637-
p = pickle.dumps(datapipe)
1680+
msgs = (
1681+
r"^Lambda function is not supported by pickle",
1682+
r"^Local function is not supported by pickle"
1683+
)
1684+
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
1685+
for dpipe, dp_args, dp_kwargs in dps:
1686+
with self.assertWarnsRegex(UserWarning, msg):
1687+
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
1688+
with self.assertRaises((pickle.PicklingError, AttributeError)):
1689+
pickle.dumps(datapipe)
16381690

16391691
def test_sequence_wrapper_datapipe(self):
16401692
seq = list(range(10))

torch/utils/data/datapipes/iter/callable.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from torch.utils.data.datapipes._decorator import functional_datapipe
44
from torch.utils.data._utils.collate import default_collate
55
from torch.utils.data.datapipes.datapipe import IterDataPipe
6-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
6+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
77

88
__all__ = [
99
"CollatorIterDataPipe",
@@ -64,7 +64,7 @@ def __init__(
6464
super().__init__()
6565
self.datapipe = datapipe
6666

67-
_check_lambda_fn(fn)
67+
_check_unpickable_fn(fn)
6868
self.fn = fn # type: ignore[assignment]
6969

7070
self.input_col = input_col

torch/utils/data/datapipes/iter/combining.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
from torch.utils.data.datapipes._decorator import functional_datapipe
77
from torch.utils.data.datapipes.datapipe import IterDataPipe
8-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
8+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
99

1010
__all__ = [
1111
"ConcaterIterDataPipe",
@@ -300,7 +300,7 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int,
300300
if num_instances < 1:
301301
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
302302

303-
_check_lambda_fn(classifier_fn)
303+
_check_unpickable_fn(classifier_fn)
304304

305305
# When num_instances == 1, demux can be replaced by filter,
306306
# but keep it as Demultiplexer for the sake of consistency

torch/utils/data/datapipes/iter/grouping.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
from torch.utils.data.datapipes._decorator import functional_datapipe
44
from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk
5-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
5+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
66
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
77

88
__all__ = [
@@ -215,7 +215,7 @@ def __init__(self,
215215
group_size: Optional[int] = None,
216216
guaranteed_group_size: Optional[int] = None,
217217
drop_remaining: bool = False):
218-
_check_lambda_fn(group_key_fn)
218+
_check_unpickable_fn(group_key_fn)
219219
self.datapipe = datapipe
220220
self.group_key_fn = group_key_fn
221221

torch/utils/data/datapipes/iter/selecting.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,10 @@
33
from torch.utils.data.datapipes._decorator import functional_datapipe
44
from torch.utils.data.datapipes.datapipe import IterDataPipe
55
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
6-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, _deprecation_warning
6+
from torch.utils.data.datapipes.utils.common import (
7+
_check_unpickable_fn,
8+
_deprecation_warning,
9+
)
710

811
__all__ = ["FilterIterDataPipe", ]
912

@@ -48,7 +51,7 @@ def __init__(
4851
super().__init__()
4952
self.datapipe = datapipe
5053

51-
_check_lambda_fn(filter_fn)
54+
_check_unpickable_fn(filter_fn)
5255
self.filter_fn = filter_fn # type: ignore[assignment]
5356

5457
if drop_empty_batches is None:

torch/utils/data/datapipes/map/callable.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
1+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
22
from typing import Callable, TypeVar
33
from torch.utils.data.datapipes._decorator import functional_datapipe
44
from torch.utils.data.datapipes.datapipe import MapDataPipe
@@ -48,7 +48,7 @@ def __init__(
4848
) -> None:
4949
super().__init__()
5050
self.datapipe = datapipe
51-
_check_lambda_fn(fn)
51+
_check_unpickable_fn(fn)
5252
self.fn = fn # type: ignore[assignment]
5353

5454
def __len__(self) -> int:

0 commit comments

Comments
 (0)