|
1 | 1 | import logging |
2 | | -import os |
3 | 2 | import select |
4 | 3 | import socket |
5 | | -import ssl |
6 | | -from concurrent.futures import ThreadPoolExecutor |
7 | 4 | from typing import Union |
8 | 5 |
|
9 | 6 | from localstack.constants import BIND_HOST, LOCALHOST_IP |
10 | | -from localstack.utils.files import new_tmp_file, save_file |
11 | 7 | from localstack.utils.functions import run_safe |
12 | 8 | from localstack.utils.numbers import is_number |
13 | | -from localstack.utils.serving import Server |
14 | | -from localstack.utils.ssl import create_ssl_cert |
15 | 9 | from localstack.utils.threads import start_worker_thread |
16 | 10 |
|
17 | 11 | LOG = logging.getLogger(__name__) |
@@ -79,98 +73,3 @@ def handle_request(s_src, thread): |
79 | 73 | start_worker_thread(lambda *args, _thread: handle_request(src_socket, _thread)) |
80 | 74 | except socket.timeout: |
81 | 75 | pass |
82 | | - |
83 | | - |
84 | | -def _save_cert_keys(client_cert_key: tuple[str, str]) -> tuple[str, str]: |
85 | | - """ |
86 | | - Save the given cert / key into files and returns their filename |
87 | | - :param client_cert_key: tuple with (client_cert, client_key) |
88 | | - :return: tuple of paths to files containing (client_cert, client_key) |
89 | | - """ |
90 | | - cert_file = client_cert_key[0] |
91 | | - if not os.path.exists(cert_file): |
92 | | - cert_file = new_tmp_file() |
93 | | - save_file(cert_file, client_cert_key[0]) |
94 | | - key_file = client_cert_key[1] |
95 | | - if not os.path.exists(key_file): |
96 | | - key_file = new_tmp_file() |
97 | | - save_file(key_file, client_cert_key[1]) |
98 | | - return cert_file, key_file |
99 | | - |
100 | | - |
101 | | -class TLSProxyServer(Server): |
102 | | - thread_pool: ThreadPoolExecutor |
103 | | - client_certs: tuple[str, str] |
104 | | - socket: socket.socket | None |
105 | | - target_host: str |
106 | | - target_port: str |
107 | | - |
108 | | - def __init__( |
109 | | - self, |
110 | | - port: int, |
111 | | - target: str, |
112 | | - host: str = "localhost", |
113 | | - client_certs: tuple[str, str] | None = None, |
114 | | - ): |
115 | | - super().__init__(port, host) |
116 | | - self.target_host, _, self.target_port = target.partition(":") |
117 | | - self.thread_pool = ThreadPoolExecutor() |
118 | | - self.client_certs = client_certs |
119 | | - self.socket = None |
120 | | - |
121 | | - def _handle_socket(self, source_socket: ssl.SSLSocket, client_address: str) -> None: |
122 | | - LOG.debug( |
123 | | - "Handling connection from %s to %s:%s", |
124 | | - client_address, |
125 | | - self.target_host, |
126 | | - self.target_port, |
127 | | - ) |
128 | | - try: |
129 | | - context = ssl.create_default_context() |
130 | | - context.check_hostname = False |
131 | | - context.verify_mode = ssl.CERT_NONE |
132 | | - if self.client_certs: |
133 | | - LOG.debug("Configuring ssl proxy to use client certs") |
134 | | - cert_file, key_file = _save_cert_keys(client_cert_key=self.client_certs) |
135 | | - context.load_cert_chain(certfile=cert_file, keyfile=key_file) |
136 | | - with socket.create_connection((self.target_host, int(self.target_port))) as sock: |
137 | | - with context.wrap_socket(sock, server_hostname=self.target_host) as target_socket: |
138 | | - sockets = [source_socket, target_socket] |
139 | | - while not self._stopped.is_set(): |
140 | | - s_read, _, _ = select.select(sockets, [], []) |
141 | | - |
142 | | - for s in s_read: |
143 | | - data = s.recv(BUFFER_SIZE) |
144 | | - if not data: |
145 | | - return |
146 | | - |
147 | | - if s == source_socket: |
148 | | - target_socket.sendall(data) |
149 | | - elif s == target_socket: |
150 | | - source_socket.sendall(data) |
151 | | - except Exception as e: |
152 | | - LOG.warning( |
153 | | - "Error while proxying SSL request: %s", e, exc_info=LOG.isEnabledFor(logging.DEBUG) |
154 | | - ) |
155 | | - |
156 | | - def do_run(self): |
157 | | - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
158 | | - |
159 | | - _, cert_file_name, key_file_name = create_ssl_cert() |
160 | | - context.load_cert_chain(cert_file_name, key_file_name) |
161 | | - |
162 | | - with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock: |
163 | | - self.socket = sock |
164 | | - sock.bind((self.host, self.port)) |
165 | | - sock.listen() |
166 | | - with context.wrap_socket(sock, server_side=True) as ssock: |
167 | | - while not self._stopped.is_set(): |
168 | | - try: |
169 | | - conn, addr = ssock.accept() |
170 | | - self.thread_pool.submit(self._handle_socket, conn, addr) |
171 | | - except Exception as e: |
172 | | - LOG.exception("Error accepting socket: %s", e) |
173 | | - |
174 | | - def do_shutdown(self): |
175 | | - if self.socket: |
176 | | - self.socket.close() |
0 commit comments