Skip to content
This repository was archived by the owner on Jan 18, 2025. It is now read-only.

Commit f79e2d3

Browse files
committed
When refreshing credentials with an stream body, rewind the stream before re-sending the original request.
1 parent d93ed1e commit f79e2d3

3 files changed

Lines changed: 41 additions & 5 deletions

File tree

oauth2client/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,11 @@ def new_request(uri, method='GET', body=None, headers=None,
554554
else:
555555
headers['user-agent'] = self.user_agent
556556

557+
body_stream_position = None
558+
if all(hasattr(body, stream_prop) for stream_prop in ('read',
559+
'seek', 'tell')):
560+
body_stream_position = body.tell()
561+
557562
resp, content = request_orig(uri, method, body, clean_headers(headers),
558563
redirections, connection_type)
559564

@@ -567,6 +572,9 @@ def new_request(uri, method='GET', body=None, headers=None,
567572
refresh_attempt + 1, max_refresh_attempts)
568573
self._refresh(request_orig)
569574
self.apply(headers)
575+
if body_stream_position is not None:
576+
body.seek(body_stream_position)
577+
570578
resp, content = request_orig(uri, method, body, clean_headers(headers),
571579
redirections, connection_type)
572580

tests/http_mock.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,16 @@ def request(self, uri,
100100
connection_type=None):
101101
resp, content = self._iterable.pop(0)
102102
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
103+
# Read any underlying stream before sending the request.
104+
body_stream_content = body.read() if getattr(body, 'read', None) else None
103105
if content == 'echo_request_headers':
104106
content = headers
105107
elif content == 'echo_request_headers_as_json':
106108
content = json.dumps(headers)
107109
elif content == 'echo_request_body':
108-
if hasattr(body, 'read'):
109-
content = body.read()
110-
else:
111-
content = body
110+
content = body if body_stream_content is None else body_stream_content
112111
elif content == 'echo_request_uri':
113112
content = uri
114113
elif not isinstance(content, bytes):
115-
raise TypeError("http content should be bytes: %r" % (content,))
114+
raise TypeError('http content should be bytes: %r' % (content,))
116115
return httplib2.Response(resp), content

tests/test_file.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import unittest
3333

3434
from .http_mock import HttpMockSequence
35+
import six
36+
3537
from oauth2client import file
3638
from oauth2client import locked_file
3739
from oauth2client import multistore_file
@@ -178,6 +180,33 @@ def test_token_refresh_good_store(self):
178180
credentials._refresh(lambda x: x)
179181
self.assertEquals(credentials.access_token, 'bar')
180182

183+
def test_token_refresh_stream_body(self):
184+
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
185+
credentials = self.create_test_credentials(expiration=expiration)
186+
187+
s = file.Storage(FILENAME)
188+
s.put(credentials)
189+
credentials = s.get()
190+
new_cred = copy.copy(credentials)
191+
new_cred.access_token = 'bar'
192+
s.put(new_cred)
193+
194+
valid_access_token = '1/3w'
195+
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
196+
http = HttpMockSequence([
197+
({'status': '401'}, b'Initial token expired'),
198+
({'status': '401'}, b'Store token expired'),
199+
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
200+
({'status': '200'}, 'echo_request_body')
201+
])
202+
203+
body = six.StringIO('streaming body')
204+
205+
credentials.authorize(http)
206+
_, content = http.request('https://example.com', body=body)
207+
self.assertEquals(content, 'streaming body')
208+
self.assertEquals(credentials.access_token, valid_access_token)
209+
181210
def test_credentials_delete(self):
182211
credentials = self.create_test_credentials()
183212

0 commit comments

Comments
 (0)