Skip to content

Commit e963df1

Browse files
committed
Update on "[DataPipe] Separating DataPipes from Dataset into different files"
Separating DataPipes from Dataset into different files. This makes the code more maintainable and simplifies some of the code generation. I have also tried to move `datapipe.py` into `torch.utils.data.datapipes`, but that will lead to circular import and rewriting many import statements. Should I put more time and go down that path some more? Fixes meta-pytorch/data#213 Differential Revision: [D34481962](https://our.internmc.facebook.com/intern/diff/D34481962) [ghstack-poisoned]
2 parents 1a90bc2 + c027c7b commit e963df1

File tree

8 files changed

+69
-32
lines changed

8 files changed

+69
-32
lines changed

torch/utils/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from torch.utils.data.dataset import (
1212
ChainDataset,
1313
ConcatDataset,
14-
DataChunk,
1514
Dataset,
1615
IterableDataset,
1716
Subset,
@@ -22,6 +21,7 @@
2221
DFIterDataPipe,
2322
IterDataPipe,
2423
MapDataPipe,
24+
DataChunk,
2525
)
2626
from torch.utils.data.dataloader import (
2727
DataLoader,

torch/utils/data/datapipes/dataframe/structures.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from torch.utils.data import (
2-
DataChunk,
3-
)
1+
from torch.utils.data.datapipes.datapipe import DataChunk
2+
43

54
class DataChunkDF(DataChunk):
65
"""

torch/utils/data/datapipes/datapipe.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import functools
2-
from typing import Dict, Callable, Optional, TypeVar
2+
from typing import Dict, Callable, Optional, TypeVar, Generic, Iterator
33

44
from torch.utils.data.datapipes._typing import _DataPipeMeta
55
from torch.utils.data._utils.serialization import serialize_fn, SerializationType, deserialize_fn
66
from torch.utils.data.dataset import Dataset, IterableDataset
77

8-
8+
T = TypeVar('T')
99
T_co = TypeVar('T_co', covariant=True)
1010

1111
UNTRACABLE_DATAFRAME_PIPES = ['batch', # As it returns DataChunks
@@ -206,3 +206,21 @@ def __setstate__(self, state_dict):
206206
self.__dict__[k] = deserialize_fn(v)
207207
else:
208208
self.__dict__[k] = v
209+
210+
211+
class DataChunk(list, Generic[T]):
212+
def __init__(self, items):
213+
super().__init__(items)
214+
self.items = items
215+
216+
def as_str(self, indent=''):
217+
res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
218+
return res
219+
220+
def __iter__(self) -> Iterator[T]:
221+
for i in super().__iter__():
222+
yield i
223+
224+
def raw_iterator(self) -> T: # type: ignore[misc]
225+
for i in self.items:
226+
yield i

torch/utils/data/datapipes/datapipe.pyi

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from torch import Generator as Generator, Tensor as Tensor
77
from torch import default_generator as default_generator, randperm as randperm
88
from torch.utils.data.datapipes._typing import _DataPipeMeta
9-
from typing import Any, Callable, Dict, List, Optional, TypeVar
10-
from torch.utils.data import DataChunk, Dataset, IterableDataset
9+
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar
10+
from torch.utils.data import Dataset, IterableDataset
1111

1212
T_co = TypeVar('T_co', covariant=True)
1313
T = TypeVar('T')
@@ -32,6 +32,7 @@ class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
3232
# Functional form of 'ZipperMapDataPipe'
3333
def zip(self, *datapipes: MapDataPipe[T_co]) -> MapDataPipe: ...
3434

35+
3536
class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
3637
functions: Dict[str, Callable] = ...
3738
reduce_ex_hook: Optional[Callable] = ...
@@ -76,5 +77,24 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
7677
# Functional form of 'ZipperIterDataPipe'
7778
def zip(self, *datapipes: IterDataPipe) -> IterDataPipe: ...
7879

80+
7981
class DFIterDataPipe(IterDataPipe):
8082
def _is_dfpipe(self): ...
83+
84+
85+
class DataChunk(list, Generic[T]):
86+
def __init__(self, items):
87+
super().__init__(items)
88+
self.items = items
89+
90+
def as_str(self, indent=''):
91+
res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
92+
return res
93+
94+
def __iter__(self) -> Iterator[T]:
95+
for i in super().__iter__():
96+
yield i
97+
98+
def raw_iterator(self) -> T: # type: ignore[misc]
99+
for i in self.items:
100+
yield i

torch/utils/data/datapipes/datapipe.pyi.in

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from torch import Generator as Generator, Tensor as Tensor
77
from torch import default_generator as default_generator, randperm as randperm
88
from torch.utils.data.datapipes._typing import _DataPipeMeta
9-
from typing import Any, Callable, Dict, List, Optional, TypeVar
10-
from torch.utils.data import DataChunk, Dataset, IterableDataset
9+
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar
10+
from torch.utils.data import Dataset, IterableDataset
1111

1212
T_co = TypeVar('T_co', covariant=True)
1313
T = TypeVar('T')
@@ -23,6 +23,7 @@ class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
2323
def register_datapipe_as_function(cls, function_name: Any, cls_to_register: Any): ...
2424
${MapDataPipeMethods}
2525

26+
2627
class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
2728
functions: Dict[str, Callable] = ...
2829
reduce_ex_hook: Optional[Callable] = ...
@@ -40,5 +41,24 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
4041
def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
4142
${IterDataPipeMethods}
4243

44+
4345
class DFIterDataPipe(IterDataPipe):
4446
def _is_dfpipe(self): ...
47+
48+
49+
class DataChunk(list, Generic[T]):
50+
def __init__(self, items):
51+
super().__init__(items)
52+
self.items = items
53+
54+
def as_str(self, indent=''):
55+
res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
56+
return res
57+
58+
def __iter__(self) -> Iterator[T]:
59+
for i in super().__iter__():
60+
yield i
61+
62+
def raw_iterator(self) -> T: # type: ignore[misc]
63+
for i in self.items:
64+
yield i

torch/utils/data/datapipes/iter/grouping.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from collections import defaultdict
22

3-
from torch.utils.data import DataChunk
43
from torch.utils.data.datapipes._decorator import functional_datapipe
5-
from torch.utils.data.datapipes.datapipe import IterDataPipe
4+
from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk
65
from torch.utils.data.datapipes.utils.common import check_lambda_fn
76
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
87

torch/utils/data/datapipes/map/grouping.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from torch.utils.data import DataChunk
21
from torch.utils.data.datapipes._decorator import functional_datapipe
3-
from torch.utils.data.datapipes.datapipe import MapDataPipe
2+
from torch.utils.data.datapipes.datapipe import MapDataPipe, DataChunk
43
from typing import List, Optional, Sized, TypeVar
54

65

torch/utils/data/dataset.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,6 @@
2121
T = TypeVar('T')
2222

2323

24-
class DataChunk(list, Generic[T]):
25-
def __init__(self, items):
26-
super().__init__(items)
27-
self.items = items
28-
29-
def as_str(self, indent=''):
30-
res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
31-
return res
32-
33-
def __iter__(self) -> Iterator[T]:
34-
for i in super().__iter__():
35-
yield i
36-
37-
def raw_iterator(self) -> T: # type: ignore[misc]
38-
for i in self.items:
39-
yield i
40-
41-
4224
class Dataset(Generic[T_co]):
4325
r"""An abstract class representing a :class:`Dataset`.
4426

0 commit comments

Comments
 (0)