|
4 | 4 | import socket |
5 | 5 | import sys |
6 | 6 | import unittest |
| 7 | +import weakref |
7 | 8 | from unittest import mock |
8 | 9 | try: |
9 | 10 | import ssl |
@@ -270,6 +271,72 @@ async def client(addr): |
270 | 271 | self.loop.run_until_complete( |
271 | 272 | asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10)) |
272 | 273 |
|
| 274 | + # No garbage is left if SSL is closed uncleanly |
| 275 | + client_context = weakref.ref(client_context) |
| 276 | + self.assertIsNone(client_context()) |
| 277 | + |
| 278 | + def test_create_connection_memory_leak(self): |
| 279 | + HELLO_MSG = b'1' * self.PAYLOAD_SIZE |
| 280 | + |
| 281 | + server_context = test_utils.simple_server_sslcontext() |
| 282 | + client_context = test_utils.simple_client_sslcontext() |
| 283 | + |
| 284 | + def serve(sock): |
| 285 | + sock.settimeout(self.TIMEOUT) |
| 286 | + |
| 287 | + sock.start_tls(server_context, server_side=True) |
| 288 | + |
| 289 | + sock.sendall(b'O') |
| 290 | + data = sock.recv_all(len(HELLO_MSG)) |
| 291 | + self.assertEqual(len(data), len(HELLO_MSG)) |
| 292 | + |
| 293 | + sock.shutdown(socket.SHUT_RDWR) |
| 294 | + sock.close() |
| 295 | + |
| 296 | + class ClientProto(asyncio.Protocol): |
| 297 | + def __init__(self, on_data, on_eof): |
| 298 | + self.on_data = on_data |
| 299 | + self.on_eof = on_eof |
| 300 | + self.con_made_cnt = 0 |
| 301 | + |
| 302 | + def connection_made(proto, tr): |
| 303 | + # XXX: We assume user stores the transport in protocol |
| 304 | + proto.tr = tr |
| 305 | + proto.con_made_cnt += 1 |
| 306 | + # Ensure connection_made gets called only once. |
| 307 | + self.assertEqual(proto.con_made_cnt, 1) |
| 308 | + |
| 309 | + def data_received(self, data): |
| 310 | + self.on_data.set_result(data) |
| 311 | + |
| 312 | + def eof_received(self): |
| 313 | + self.on_eof.set_result(True) |
| 314 | + |
| 315 | + async def client(addr): |
| 316 | + await asyncio.sleep(0.5) |
| 317 | + |
| 318 | + on_data = self.loop.create_future() |
| 319 | + on_eof = self.loop.create_future() |
| 320 | + |
| 321 | + tr, proto = await self.loop.create_connection( |
| 322 | + lambda: ClientProto(on_data, on_eof), *addr, |
| 323 | + ssl=client_context) |
| 324 | + |
| 325 | + self.assertEqual(await on_data, b'O') |
| 326 | + tr.write(HELLO_MSG) |
| 327 | + await on_eof |
| 328 | + |
| 329 | + tr.close() |
| 330 | + |
| 331 | + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: |
| 332 | + self.loop.run_until_complete( |
| 333 | + asyncio.wait_for(client(srv.addr), timeout=10)) |
| 334 | + |
| 335 | + # No garbage is left for SSL client from loop.create_connection, even |
| 336 | + # if user stores the SSLTransport in corresponding protocol instance |
| 337 | + client_context = weakref.ref(client_context) |
| 338 | + self.assertIsNone(client_context()) |
| 339 | + |
273 | 340 | def test_start_tls_client_buf_proto_1(self): |
274 | 341 | HELLO_MSG = b'1' * self.PAYLOAD_SIZE |
275 | 342 |
|
@@ -560,6 +627,11 @@ async def client(addr): |
560 | 627 | # exception or log an error, even if the handshake failed |
561 | 628 | self.assertEqual(messages, []) |
562 | 629 |
|
| 630 | + # The 10s handshake timeout should be cancelled to free related |
| 631 | + # objects without really waiting for 10s |
| 632 | + client_sslctx = weakref.ref(client_sslctx) |
| 633 | + self.assertIsNone(client_sslctx()) |
| 634 | + |
563 | 635 | def test_create_connection_ssl_slow_handshake(self): |
564 | 636 | client_sslctx = test_utils.simple_client_sslcontext() |
565 | 637 |
|
|
0 commit comments