Skip to content

Commit caa6bdb

Browse files
committed
refactor MultipartWriter to use Payload
1 parent 7e9a381 commit caa6bdb

File tree

11 files changed

+786
-862
lines changed

11 files changed

+786
-862
lines changed

aiohttp/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from . import hdrs # noqa
66
from .client import * # noqa
7+
from .formdata import * # noqa
78
from .helpers import * # noqa
89
from .http_message import HttpVersion, HttpVersion10, HttpVersion11 # noqa
910
from .http_websocket import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa
@@ -25,11 +26,12 @@
2526

2627

2728
__all__ = (client.__all__ + # noqa
29+
formdata.__all__ + # noqa
2830
helpers.__all__ + # noqa
29-
streams.__all__ + # noqa
31+
multipart.__all__ + # noqa
3032
payload.__all__ + # noqa
3133
payload_streamer.__all__ + # noqa
32-
multipart.__all__ + # noqa
34+
streams.__all__ + # noqa
3335
('hdrs', 'FileSender',
3436
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
3537
'WSMsgType', 'MsgType', 'WSCloseCode',

aiohttp/client_reqrep.py

Lines changed: 34 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import aiohttp
1313

1414
from . import hdrs, helpers, http, payload
15+
from .formdata import FormData
1516
from .helpers import PY_35, HeadersMixin, SimpleCookie, _TimeServiceTimeoutNoop
1617
from .http import HttpMessage
1718
from .log import client_logger
18-
from .multipart import MultipartWriter
1919
from .streams import FlowControlStreamReader
2020

2121
try:
@@ -217,71 +217,54 @@ def update_auth(self, auth):
217217

218218
self.headers[hdrs.AUTHORIZATION] = auth.encode()
219219

220-
def update_body_from_data(self, data, skip_auto_headers):
221-
if not data:
222-
return
223-
224-
try:
225-
self.body = payload.PAYLOAD_REGISTRY.get(data)
226-
227-
# enable chunked encoding if needed
228-
if not self.chunked:
229-
if hdrs.CONTENT_LENGTH not in self.headers:
230-
size = self.body.size
231-
if size is None:
232-
self.chunked = True
233-
else:
234-
if hdrs.CONTENT_LENGTH not in self.headers:
235-
self.headers[hdrs.CONTENT_LENGTH] = str(size)
236-
237-
# set content-type
238-
if (hdrs.CONTENT_TYPE not in self.headers and
239-
hdrs.CONTENT_TYPE not in skip_auto_headers):
240-
self.headers[hdrs.CONTENT_TYPE] = self.body.content_type
241-
242-
# copy payload headers
243-
if self.body.headers:
244-
for (key, value) in self.body.headers.items():
245-
if key not in self.headers:
246-
self.headers[key] = value
247-
248-
except payload.LookupError:
249-
pass
250-
else:
220+
def update_body_from_data(self, body, skip_auto_headers):
221+
if not body:
251222
return
252223

253-
if asyncio.iscoroutine(data):
224+
if asyncio.iscoroutine(body):
254225
warnings.warn(
255226
'coroutine as data object is deprecated, '
256227
'use aiohttp.streamer #1664',
257228
DeprecationWarning, stacklevel=2)
258229

259-
self.body = data
230+
self.body = body
260231
if (hdrs.CONTENT_LENGTH not in self.headers and
261232
self.chunked is None):
262233
self.chunked = True
263234

264-
elif isinstance(data, MultipartWriter):
265-
self.body = data.serialize()
266-
self.headers.update(data.headers)
267-
self.chunked = True
235+
return
268236

269-
else:
270-
if not isinstance(data, helpers.FormData):
271-
data = helpers.FormData(data)
237+
# FormData
238+
if isinstance(body, FormData):
239+
body = body(self.encoding)
272240

273-
self.body = data(self.encoding)
241+
try:
242+
body = payload.PAYLOAD_REGISTRY.get(body)
243+
except payload.LookupError:
244+
body = FormData(body)(self.encoding)
274245

275-
if (hdrs.CONTENT_TYPE not in self.headers and
276-
hdrs.CONTENT_TYPE not in skip_auto_headers):
277-
self.headers[hdrs.CONTENT_TYPE] = data.content_type
246+
self.body = body
278247

279-
if data.is_multipart:
280-
self.chunked = True
281-
else:
282-
if (hdrs.CONTENT_LENGTH not in self.headers and
283-
not self.chunked):
284-
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
248+
# enable chunked encoding if needed
249+
if not self.chunked:
250+
if hdrs.CONTENT_LENGTH not in self.headers:
251+
size = body.size
252+
if size is None:
253+
self.chunked = True
254+
else:
255+
if hdrs.CONTENT_LENGTH not in self.headers:
256+
self.headers[hdrs.CONTENT_LENGTH] = str(size)
257+
258+
# set content-type
259+
if (hdrs.CONTENT_TYPE not in self.headers and
260+
hdrs.CONTENT_TYPE not in skip_auto_headers):
261+
self.headers[hdrs.CONTENT_TYPE] = body.content_type
262+
263+
# copy payload headers
264+
if body.headers:
265+
for (key, value) in body.headers.items():
266+
if key not in self.headers:
267+
self.headers[key] = value
285268

286269
def update_transfer_encoding(self):
287270
"""Analyze transfer-encoding header."""

aiohttp/formdata.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import io
2+
from urllib.parse import urlencode
3+
4+
from multidict import MultiDict, MultiDictProxy
5+
6+
from . import hdrs, multipart, payload
7+
from .helpers import guess_filename
8+
9+
__all__ = ('FormData',)
10+
11+
12+
class FormData:
13+
"""Helper class for multipart/form-data and
14+
application/x-www-form-urlencoded body generation."""
15+
16+
def __init__(self, fields=(), quote_fields=True):
17+
self._writer = multipart.MultipartWriter('form-data')
18+
self._fields = []
19+
self._is_multipart = False
20+
self._quote_fields = quote_fields
21+
22+
if isinstance(fields, dict):
23+
fields = list(fields.items())
24+
elif not isinstance(fields, (list, tuple)):
25+
fields = (fields,)
26+
self.add_fields(*fields)
27+
28+
def add_field(self, name, value, *, content_type=None, filename=None,
29+
content_transfer_encoding=None):
30+
31+
if isinstance(value, io.IOBase):
32+
self._is_multipart = True
33+
elif isinstance(value, (bytes, bytearray, memoryview)):
34+
if filename is None and content_transfer_encoding is None:
35+
filename = name
36+
37+
type_options = MultiDict({'name': name})
38+
if filename is not None and not isinstance(filename, str):
39+
raise TypeError('filename must be an instance of str. '
40+
'Got: %s' % filename)
41+
if filename is None and isinstance(value, io.IOBase):
42+
filename = guess_filename(value, name)
43+
if filename is not None:
44+
type_options['filename'] = filename
45+
self._is_multipart = True
46+
47+
headers = {}
48+
if content_type is not None:
49+
if not isinstance(content_type, str):
50+
raise TypeError('content_type must be an instance of str. '
51+
'Got: %s' % content_type)
52+
headers[hdrs.CONTENT_TYPE] = content_type
53+
self._is_multipart = True
54+
if content_transfer_encoding is not None:
55+
if not isinstance(content_transfer_encoding, str):
56+
raise TypeError('content_transfer_encoding must be an instance'
57+
' of str. Got: %s' % content_transfer_encoding)
58+
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
59+
self._is_multipart = True
60+
61+
self._fields.append((type_options, headers, value))
62+
63+
def add_fields(self, *fields):
64+
to_add = list(fields)
65+
66+
while to_add:
67+
rec = to_add.pop(0)
68+
69+
if isinstance(rec, io.IOBase):
70+
k = guess_filename(rec, 'unknown')
71+
self.add_field(k, rec)
72+
73+
elif isinstance(rec, (MultiDictProxy, MultiDict)):
74+
to_add.extend(rec.items())
75+
76+
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
77+
k, fp = rec
78+
self.add_field(k, fp)
79+
80+
else:
81+
raise TypeError('Only io.IOBase, multidict and (name, file) '
82+
'pairs allowed, use .add_field() for passing '
83+
'more complex parameters, got {!r}'
84+
.format(rec))
85+
86+
def _gen_form_urlencoded(self, encoding):
87+
# form data (x-www-form-urlencoded)
88+
data = []
89+
for type_options, _, value in self._fields:
90+
data.append((type_options['name'], value))
91+
92+
return payload.BytesPayload(
93+
urlencode(data, doseq=True).encode(encoding),
94+
content_type='application/x-www-form-urlencoded')
95+
96+
def _gen_form_data(self, encoding):
97+
"""Encode a list of fields using the multipart/form-data MIME format"""
98+
for dispparams, headers, value in self._fields:
99+
if hdrs.CONTENT_TYPE in headers:
100+
part = payload.get_payload(
101+
value, content_type=headers[hdrs.CONTENT_TYPE],
102+
headers=headers, encoding=encoding)
103+
else:
104+
part = payload.get_payload(
105+
value, headers=headers, encoding=encoding)
106+
if dispparams:
107+
part.set_content_disposition(
108+
'form-data', quote_fields=self._quote_fields, **dispparams
109+
)
110+
# FIXME cgi.FieldStorage doesn't likes body parts with
111+
# Content-Length which were sent via chunked transfer encoding
112+
part.headers.pop(hdrs.CONTENT_LENGTH, None)
113+
114+
self._writer.append_payload(part)
115+
116+
return self._writer
117+
118+
def __call__(self, encoding):
119+
if self._is_multipart:
120+
return self._gen_form_data(encoding)
121+
else:
122+
return self._gen_form_urlencoded(encoding)

0 commit comments

Comments
 (0)