|
4 | 4 | import socket |
5 | 5 | import sys |
6 | 6 | import unittest |
| 7 | +import weakref |
7 | 8 | from unittest import mock |
8 | 9 | try: |
9 | 10 | import ssl |
@@ -274,6 +275,72 @@ async def client(addr): |
274 | 275 | self.loop.run_until_complete( |
275 | 276 | asyncio.wait_for(client(srv.addr), timeout=10)) |
276 | 277 |
|
| 278 | + # No garbage is left if SSL is closed uncleanly |
| 279 | + client_context = weakref.ref(client_context) |
| 280 | + self.assertIsNone(client_context()) |
| 281 | + |
| 282 | + def test_create_connection_memory_leak(self): |
| 283 | + HELLO_MSG = b'1' * self.PAYLOAD_SIZE |
| 284 | + |
| 285 | + server_context = test_utils.simple_server_sslcontext() |
| 286 | + client_context = test_utils.simple_client_sslcontext() |
| 287 | + |
| 288 | + def serve(sock): |
| 289 | + sock.settimeout(self.TIMEOUT) |
| 290 | + |
| 291 | + sock.start_tls(server_context, server_side=True) |
| 292 | + |
| 293 | + sock.sendall(b'O') |
| 294 | + data = sock.recv_all(len(HELLO_MSG)) |
| 295 | + self.assertEqual(len(data), len(HELLO_MSG)) |
| 296 | + |
| 297 | + sock.shutdown(socket.SHUT_RDWR) |
| 298 | + sock.close() |
| 299 | + |
| 300 | + class ClientProto(asyncio.Protocol): |
| 301 | + def __init__(self, on_data, on_eof): |
| 302 | + self.on_data = on_data |
| 303 | + self.on_eof = on_eof |
| 304 | + self.con_made_cnt = 0 |
| 305 | + |
| 306 | + def connection_made(proto, tr): |
| 307 | + # XXX: We assume user stores the transport in protocol |
| 308 | + proto.tr = tr |
| 309 | + proto.con_made_cnt += 1 |
| 310 | + # Ensure connection_made gets called only once. |
| 311 | + self.assertEqual(proto.con_made_cnt, 1) |
| 312 | + |
| 313 | + def data_received(self, data): |
| 314 | + self.on_data.set_result(data) |
| 315 | + |
| 316 | + def eof_received(self): |
| 317 | + self.on_eof.set_result(True) |
| 318 | + |
| 319 | + async def client(addr): |
| 320 | + await asyncio.sleep(0.5) |
| 321 | + |
| 322 | + on_data = self.loop.create_future() |
| 323 | + on_eof = self.loop.create_future() |
| 324 | + |
| 325 | + tr, proto = await self.loop.create_connection( |
| 326 | + lambda: ClientProto(on_data, on_eof), *addr, |
| 327 | + ssl=client_context) |
| 328 | + |
| 329 | + self.assertEqual(await on_data, b'O') |
| 330 | + tr.write(HELLO_MSG) |
| 331 | + await on_eof |
| 332 | + |
| 333 | + tr.close() |
| 334 | + |
| 335 | + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: |
| 336 | + self.loop.run_until_complete( |
| 337 | + asyncio.wait_for(client(srv.addr), timeout=10)) |
| 338 | + |
| 339 | + # No garbage is left for SSL client from loop.create_connection, even |
| 340 | + # if user stores the SSLTransport in corresponding protocol instance |
| 341 | + client_context = weakref.ref(client_context) |
| 342 | + self.assertIsNone(client_context()) |
| 343 | + |
277 | 344 | def test_start_tls_client_buf_proto_1(self): |
278 | 345 | HELLO_MSG = b'1' * self.PAYLOAD_SIZE |
279 | 346 |
|
@@ -562,6 +629,11 @@ async def client(addr): |
562 | 629 | # exception or log an error, even if the handshake failed |
563 | 630 | self.assertEqual(messages, []) |
564 | 631 |
|
| 632 | + # The 10s handshake timeout should be cancelled to free related |
| 633 | + # objects without really waiting for 10s |
| 634 | + client_sslctx = weakref.ref(client_sslctx) |
| 635 | + self.assertIsNone(client_sslctx()) |
| 636 | + |
565 | 637 | def test_create_connection_ssl_slow_handshake(self): |
566 | 638 | client_sslctx = test_utils.simple_client_sslcontext() |
567 | 639 |
|
|
0 commit comments