Commit 3ab3b13

TikhonJelvis authored and yosefAlsuhaibani committed
fix(rpc): Read RPC responses as binary data in Python (semgrep/semgrep-proprietary#5117)
Change the Python RPC implementation to read from the sub-process's output stream in bytes rather than Unicode characters. All of the IO is now measured in bytes, with explicit encoding/decoding steps to convert to text.

The RPC format consists of a length in bytes followed by that many bytes of UTF-8-encoded text. However, the current Python implementation reads data from the process *as text* (`text=True` when starting the process), so `io.read(n)` counts in Unicode characters rather than bytes. When the RPC output includes non-ASCII characters, the number of bytes written in the message header is larger than the number of Unicode characters in the stream.

This has not been a problem so far because we only run a single RPC call per process. After the RPC call we close the stream and send an EOF, so `io.read(n)` will read the whole string even if it has fewer than `n` characters. However, this caused a problem when I implemented running multiple RPC calls through a single long-lived process because `io.read(n)` would block indefinitely if the stream did not contain at least `n` characters. This change fixes that problem.

Test plan: ran existing tests + reproduced the problem and fix on top of #5066.

synced from Pro d507ac7668dcccb43c12dc732a615866d53dc12b
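The mismatch described above can be reproduced without a sub-process. The following is a minimal sketch, not the Semgrep code: `frame` is a hypothetical helper, and in-memory streams from the standard `io` module stand in for the pipe. It shows that a header counted in bytes asks a text-mode reader for more characters than the payload contains.

```python
import io

ENCODING = "utf-8"


def frame(message: str) -> bytes:
    # Hypothetical helper: byte length, newline, then the UTF-8 payload.
    payload = message.encode(ENCODING)
    return str(len(payload)).encode(ENCODING) + b"\n" + payload


# "naive cafe" with two accented letters: 10 characters, 12 UTF-8 bytes.
message = "na\u00efve caf\u00e9"
wire = frame(message)

# Text mode: read(n) counts characters, so asking for 12 finds only 10.
# The in-memory stream hits EOF and returns early; a pipe that stays open
# would keep waiting for the 2 characters that never arrive.
text_stream = io.TextIOWrapper(io.BytesIO(wire), encoding=ENCODING)
size = int(text_stream.readline().strip())
text_result = text_stream.read(size)

# Binary mode: read(n) counts bytes, which matches the header exactly.
byte_stream = io.BytesIO(wire)
size_b = int(byte_stream.readline().decode(ENCODING).strip())
byte_result = byte_stream.read(size_b)

print(size, len(text_result), len(byte_result))
```

Decoding only after the full byte count has been read keeps the header and the read call in the same unit.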
1 parent e60b2d5 commit 3ab3b13

1 file changed: +9 -10 lines changed
cli/src/semgrep/rpc.py

Lines changed: 9 additions & 10 deletions
@@ -59,7 +59,7 @@
 
 
 # Read `size` bytes from `io`. Returns fewer bytes if we hit EOF.
-def _really_read(io: IO[str], size: int) -> str:
+def _really_read(io: IO[bytes], size: int) -> str:
     # Operate on bytes, not str.
     out: bytes = b""
     while len(out) < size:
@@ -73,23 +73,23 @@ def _really_read(io: IO[str], size: int) -> str:
         # clear to me (nmote) whether it is guaranteed to be present on the
         # streams provided by subprocess.Popen. So, to be on the safe side,
         # we'll just do this ourselves.
-        new: str = io.read(size)
+        new: bytes = io.read(size)
         # This happens if we hit EOF. In that case, repeatedly reading will lead
         # to an infinite loop.
         if len(new) == 0:
             logger.error(f"0 bytes read from RPC input stream")
             break
-        out = out + new.encode(ENCODING)
+        out = out + new
     # When we read the RPC call for file targeting, we could encounter files
     # with non-utf8 characters, in that case we replace them with <?>
     # i.e abc.txt -> ab<?>.txt
     return out.decode(ENCODING, errors="replace")
 
 
-def _read_packet(io: IO[str]) -> Optional[str]:
+def _read_packet(io: IO[bytes]) -> Optional[str]:
     # Unlike `read`, `readline` is guaranteed to return a full line unless there
     # is an EOF
-    size_str = io.readline().strip()
+    size_str = io.readline().decode(ENCODING).strip()
     if not size_str.isdigit():
         # Avoid horrific log spew if we somehow got a really long line
         truncated = size_str[:50]
@@ -99,12 +99,12 @@ def _read_packet(io: IO[str]) -> Optional[str]:
     return _really_read(io, size)
 
 
-def _write_packet(io: IO[str], packet: str) -> None:
+def _write_packet(io: IO[bytes], packet: str) -> None:
     # Size in bytes
     size: int = len(packet.encode(ENCODING))
     size_str = str(size) + "\n"
-    io.write(size_str)
-    io.write(packet)
+    io.write(size_str.encode(ENCODING))
+    io.write(packet.encode(ENCODING))
     io.flush()
 
 
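To illustrate why byte-based framing matters once several calls share one stream, here is a simplified round trip through an in-memory buffer. This is a sketch under assumptions, not the code in `rpc.py`: `write_packet` and `read_packet` are hypothetical stand-ins, logging is omitted, and the read loop requests only the bytes still missing, which is a choice made for this example.

```python
import io
from typing import IO, List, Optional

ENCODING = "utf-8"


def write_packet(stream: IO[bytes], packet: str) -> None:
    # Header is the payload size in bytes, followed by a newline.
    payload = packet.encode(ENCODING)
    stream.write(f"{len(payload)}\n".encode(ENCODING))
    stream.write(payload)
    stream.flush()


def read_packet(stream: IO[bytes]) -> Optional[str]:
    header = stream.readline().decode(ENCODING).strip()
    if not header.isdigit():
        # Empty or malformed header, for example at EOF.
        return None
    size = int(header)
    out = b""
    while len(out) < size:
        chunk = stream.read(size - len(out))
        if not chunk:
            break  # EOF before the full payload arrived
        out += chunk
    return out.decode(ENCODING, errors="replace")


# Several packets on one stream, some with multi-byte characters.
messages: List[str] = ["hello", "na\u00efve caf\u00e9", "\u65e5\u672c\u8a9e"]
buffer = io.BytesIO()
for m in messages:
    write_packet(buffer, m)

buffer.seek(0)
received = [read_packet(buffer) for _ in messages]
at_eof = read_packet(buffer)
print(received == messages, at_eof)
```

Because each read consumes exactly the number of bytes announced in the header, the next `readline` starts at the next header, even when the payload contains multi-byte characters.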

@@ -156,8 +156,7 @@ def rpc_call(call: out.FunctionCall, cls: Type[T]) -> Optional[T]:
         cmd,
         stdin=subprocess.PIPE,
         stdout=subprocess.PIPE,
-        text=True,
-        encoding=ENCODING,
+        text=False,
     ) as proc:
         try:
             # These need to be local variables because otherwise mypy doesn't
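As a small illustration of what `text=False` changes, the sketch below starts a throwaway child process and reads its output as raw bytes. The child command is an assumption for demonstration only (a one-line Python script run through `sys.executable`), not the binary that `rpc_call` launches.

```python
import subprocess
import sys

ENCODING = "utf-8"

# Hypothetical child: writes a 4-character word whose last letter is 2 bytes.
child_code = "import sys; sys.stdout.buffer.write('caf\\u00e9'.encode('utf-8'))"

with subprocess.Popen(
    [sys.executable, "-c", child_code],
    stdout=subprocess.PIPE,
    text=False,  # pipes carry bytes; decoding is left to the caller
) as proc:
    assert proc.stdout is not None  # narrows the Optional type for mypy
    raw = proc.stdout.read()

decoded = raw.decode(ENCODING)
print(type(raw).__name__, len(raw), len(decoded))
```

With byte pipes, the caller decides where decoding happens, so lengths taken from a byte-counted header can be compared against `len(raw)` directly.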
