-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathserialize.py
More file actions
336 lines (249 loc) · 10.6 KB
/
serialize.py
File metadata and controls
336 lines (249 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""Serialization functions."""
from __future__ import annotations
import io
import pickle
import sys
from collections import OrderedDict
from collections.abc import Sized
from typing import Any
from typing import Protocol
from typing import runtime_checkable
if sys.version_info >= (3, 11): # pragma: >=3.11 cover
from typing import TypeAlias
else: # pragma: <3.11 cover
from typing import TypeAlias
if sys.version_info >= (3, 12): # pragma: >=3.12 cover
from typing import TypeGuard
else: # pragma: <3.12 cover
from typing import TypeGuard
import cloudpickle
# Pickle protocol 5 was introduced in Python 3.8, so it is the lowest protocol
# ProxyStore will use. If newer protocols are released in the future, prefer
# those: max() selects the larger of the interpreter's highest protocol and 5.
_PICKLE_PROTOCOL = max(pickle.HIGHEST_PROTOCOL, 5)
if sys.version_info >= (3, 12): # pragma: >=3.12 cover
from collections.abc import Buffer
@runtime_checkable
class BytesLike(Buffer, Sized, Protocol):
"""Protocol for bytes-like objects."""
pass
else: # pragma: <3.12 cover
BytesLike: TypeAlias = bytes | bytearray | memoryview
"""Protocol for bytes-like objects."""
class SerializationError(Exception):
    """Base Serialization Exception.

    Raised when an object cannot be serialized or deserialized.
    """
class _Serializer(Protocol):
    """Interface implemented by every registered serializer.

    By convention, `identifier` is a two-byte string that uniquely identifies
    the serializer type, and `name` is the human-readable name of the
    serializer used in logging and error messages.
    """

    identifier: bytes
    name: str

    def supported(self, obj: Any) -> bool:
        """Quickly check whether this serializer may handle the object.

        This is meant to be a fast pre-check. If `supported(obj)` returns
        `False`, then `serialize(obj)` is guaranteed to fail. The inverse
        does not hold: `serialize(obj)` can still fail even when
        `supported(obj)` returns `True`.
        """
        ...

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Write the serialized form of the object to the buffer."""
        ...

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Read bytes from the buffer and reconstruct the object."""
        ...
class _BytesSerializer:
    """Pass-through serializer for `bytes` objects."""

    identifier = b'BS'
    name = 'bytes'

    def supported(self, obj: Any) -> bool:
        """Return `True` only when `obj` is a `bytes` instance."""
        return isinstance(obj, bytes)

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Write the bytes to the buffer unchanged."""
        buffer.write(obj)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Return everything remaining in the buffer."""
        payload = buffer.read()
        return payload
class _StrSerializer:
    """Serializer that stores `str` objects as encoded bytes."""

    identifier = b'US'
    name = 'string'

    def supported(self, obj: Any) -> bool:
        """Return `True` only when `obj` is a `str` instance."""
        return isinstance(obj, str)

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Encode the string with the default codec and write it out."""
        encoded = obj.encode()
        buffer.write(encoded)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Decode the remaining buffer contents back into a string."""
        raw = buffer.read()
        return raw.decode()
class _NumpySerializer:
    """Serializer for `numpy.ndarray` objects using `numpy.save`/`load`."""

    identifier = b'NP'
    name = 'numpy'

    def supported(self, obj: Any) -> bool:
        """Return `True` only when `obj` is a `numpy.ndarray`."""
        return isinstance(obj, numpy.ndarray)

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Save the array to the buffer."""
        # Arrays holding non-numeric data can only be saved when pickling
        # is permitted, so allow_pickle must be True here.
        numpy.save(buffer, obj, allow_pickle=True)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Load an array from the buffer."""
        # NOTE: allow_pickle=True mirrors serialize(); loading pickled
        # object arrays is only safe for data from a trusted source.
        array = numpy.load(buffer, allow_pickle=True)
        return array
class _PandasSerializer:
    """Serializer for `pandas.DataFrame` objects using pandas' pickling."""

    identifier = b'PD'
    name = 'pandas'

    def supported(self, obj: Any) -> bool:
        """Return `True` only when `obj` is a `pandas.DataFrame`."""
        return isinstance(obj, pandas.DataFrame)

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Pickle the dataframe into the buffer."""
        # Pickle protocol 5 is the suggested way to serialize pandas objects
        # efficiently. Feather IPC and parquet were also tested and both
        # were slower than pickle.
        # https://github.com/dask/distributed/issues/614#issuecomment-631033227
        obj.to_pickle(buffer, protocol=_PICKLE_PROTOCOL)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Unpickle a dataframe from the buffer."""
        frame = pandas.read_pickle(buffer)
        return frame
class _PolarsSerializer:
    """Serializer for `polars.DataFrame` objects using the IPC format."""

    identifier = b'PL'
    name = 'polars'

    def supported(self, obj: Any) -> bool:
        """Return `True` only when `obj` is a `polars.DataFrame`."""
        return isinstance(obj, polars.DataFrame)

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Write the dataframe to the buffer with `write_ipc`."""
        obj.write_ipc(buffer)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Read the remaining bytes and parse them with `read_ipc`."""
        raw = buffer.read()
        return polars.read_ipc(raw)
class _PickleSerializer:
    """General-purpose serializer backed by the standard `pickle` module."""

    identifier = b'PK'
    name = 'pickle'

    def supported(self, obj: Any) -> bool:
        """Always return `True`."""
        # Every type is optimistically treated as supported. That is not
        # strictly accurate, but enumerating the exceptions is non-trivial
        # and essentially requires attempting serialization and seeing if
        # it fails.
        return True

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Pickle the object into the buffer."""
        pickle.dump(obj, buffer, protocol=_PICKLE_PROTOCOL)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Unpickle one object from the buffer."""
        obj = pickle.load(buffer)
        return obj
class _CloudPickleSerializer:
    """General-purpose serializer backed by `cloudpickle`."""

    identifier = b'CP'
    name = 'cloudpickle'

    def supported(self, obj: Any) -> bool:
        """Always return `True`."""
        # Every type is optimistically treated as supported. That is not
        # strictly accurate, but enumerating the exceptions is non-trivial
        # and essentially requires attempting serialization and seeing if
        # it fails.
        return True

    def serialize(self, obj: Any, buffer: io.BytesIO) -> None:
        """Dump the object into the buffer with cloudpickle."""
        cloudpickle.dump(obj, buffer, protocol=_PICKLE_PROTOCOL)

    def deserialize(self, buffer: io.BytesIO) -> Any:
        """Load one object from the buffer with cloudpickle."""
        obj = cloudpickle.load(buffer)
        return obj
# Registry of serializer instances keyed by identifier. It is ordered, and
# serializers are tried in the order they were inserted.
_SERIALIZERS: dict[bytes, _Serializer] = OrderedDict()


def _register_serializer(serializer: type[_Serializer]) -> None:
    """Instantiate a serializer type and add it to the registry.

    Args:
        serializer: Serializer class to instantiate. The instance is stored
            under the class's `identifier`.

    Raises:
        AssertionError: If a serializer with the same identifier has already
            been registered.
    """
    key = serializer.identifier
    if key not in _SERIALIZERS:
        _SERIALIZERS[key] = serializer()
        return

    # Report the serializer that already owns this identifier.
    existing = _SERIALIZERS[key]
    raise AssertionError(
        f'Serializer named {existing.name!r} with identifier '
        f'{existing.identifier!r} already exists.',
    )
# Serializers are tried in the order they are registered, so the order
# below is the priority order used during serialization.
_register_serializer(_BytesSerializer)
_register_serializer(_StrSerializer)

# numpy, pandas, and polars are optional dependencies: each serializer is
# registered only if its package can be imported.
try:
    import numpy
except ImportError:  # pragma: no cover
    pass
else:
    _register_serializer(_NumpySerializer)

try:
    import pandas
except ImportError:  # pragma: no cover
    pass
else:
    _register_serializer(_PandasSerializer)

try:
    import polars
except ImportError:  # pragma: no cover
    pass
else:
    _register_serializer(_PolarsSerializer)

# The general-purpose picklers come last so they act as fallbacks.
_register_serializer(_PickleSerializer)
_register_serializer(_CloudPickleSerializer)
def is_bytes_like(obj: Any) -> TypeGuard[BytesLike]:
    """Check if the object is bytes-like.

    On Python 3.12 and later the check is made against the `BytesLike`
    protocol. On older versions the object must be an instance of `bytes`,
    `bytearray`, or `memoryview`.

    Args:
        obj: Object to check.

    Returns:
        `True` if `obj` is bytes-like, otherwise `False`.
    """
    if sys.version_info < (3, 12):  # pragma: <3.12 cover
        return isinstance(obj, (bytes, bytearray, memoryview))
    else:  # pragma: >=3.12 cover
        return isinstance(obj, BytesLike)
def serialize(obj: Any) -> bytes:
    """Serialize object.

    The serialization mechanism is chosen based on the type of the object.

    - [bytes][] types are not serialized.
    - [str][] types are encoded to bytes.
    - [numpy.ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html){target=_blank}
      types are serialized using
      [numpy.save](https://numpy.org/doc/stable/reference/generated/numpy.save.html){target=_blank}.
    - [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html){target=_blank}
      types are serialized using
      [to_pickle](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_pickle.html){target=_blank}.
    - [polars.DataFrame](https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/index.html){target=_blank}
      types are serialized using
      [write_ipc](https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_ipc.html){target=_blank}.
    - Other types are
      [pickled](https://docs.python.org/3/library/pickle.html){target=_blank}.
      If pickle fails,
      [cloudpickle](https://github.com/cloudpipe/cloudpickle){target=_blank}
      is used as a fallback.

    The output starts with the identifier of the serializer that was used,
    followed by a newline, followed by the serialized payload.

    Args:
        obj: Object to serialize.

    Returns:
        Bytes-like object that can be passed to \
        [`deserialize()`][proxystore.serialize.deserialize].

    Raises:
        SerializationError: If serializing the object fails with all available
            serializers. Cloudpickle is the last resort, so this error will
            typically be raised from a cloudpickle error.
    """
    failure: Exception | None = None

    # Try each registered serializer in priority order and return the
    # output of the first one that succeeds.
    for identifier, serializer in _SERIALIZERS.items():
        if not serializer.supported(obj):
            continue

        try:
            buffer = io.BytesIO()
            header = identifier + b'\n'
            buffer.write(header)
            serializer.serialize(obj, buffer)
            return buffer.getvalue()
        except Exception as exc:
            # Remember the failure and fall through to the next serializer.
            failure = exc

    assert failure is not None
    raise SerializationError(
        f'Object of type {type(obj)} is not supported.',
    ) from failure
def deserialize(buffer: BytesLike) -> Any:
    """Deserialize object.

    Warning:
        Pickled data is not secure, and a malicious pickled object can
        execute arbitrary code when unpickled. Only unpickle data you trust.

    Args:
        buffer: Bytes-like object produced by
            [`serialize()`][proxystore.serialize.serialize].

    Returns:
        The deserialized object.

    Raises:
        ValueError: If `buffer` is not bytes-like.
        SerializationError: If the identifier of `buffer` is missing or
            invalid. The identifier is prepended to the data in
            [`serialize()`][proxystore.serialize.serialize] to indicate which
            serialization method was used (e.g., no serialization, pickle,
            etc.).
        SerializationError: If the selected serializer raises an exception
            when deserializing the object.
    """
    if not is_bytes_like(buffer):
        raise ValueError(
            f'Expected data to be a bytes-like type, not {type(buffer)}.',
        )

    with io.BytesIO(buffer) as stream:
        # The first line holds the identifier of the serializer that was
        # used; the payload is everything after it.
        identifier = stream.readline().strip()
        if identifier not in _SERIALIZERS:
            raise SerializationError(
                f'Unknown identifier {identifier!r} for deserialization.',
            )

        serializer = _SERIALIZERS[identifier]
        try:
            return serializer.deserialize(stream)
        except Exception as exc:
            raise SerializationError(
                'Failed to deserialize object using the '
                f'{serializer.name} serializer.',
            ) from exc