Skip to content

Commit 40d7c10

Browse files
rohan-varmafacebook-github-bot
authored andcommitted
Unescape string in RPC error message (#49373)
Summary: Pull Request resolved: #49373 Unescaping the string in RPC error message to provide better error msg Test Plan: CI Reviewed By: xush6528 Differential Revision: D25511730 fbshipit-source-id: 054f46d5ffbcb1350012362a023fafb1fe57fca1
1 parent a9137ae commit 40d7c10

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

torch/distributed/rpc/internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _run_function(python_udf):
201201

202202
def _handle_exception(result):
203203
if isinstance(result, RemoteException):
204-
raise result.exception_type(result.msg)
204+
raise result.exception_type(result.msg.encode("utf-8").decode("unicode_escape"))
205205

206206

207207
def _build_rpc_profiling_key(

torch/testing/_internal/distributed/rpc/rpc_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,10 @@ def my_script_func(tensor):
317317
def raise_func():
318318
raise ValueError(expected_err)
319319

320+
expected_err_escape = "\nFirst line of error \n next line of error \n last line of error"
321+
def raise_func_escape():
322+
raise ValueError(expected_err_escape)
323+
320324

321325
global_rref = None
322326

@@ -1982,6 +1986,20 @@ def test_py_raise_in_user_func(self):
19821986
stderr_lines = err.getvalue()
19831987
self.assertTrue(expected_err in stderr_lines)
19841988

1989+
@dist_init
1990+
def test_py_raise_in_user_func_escaped_str(self):
1991+
n = self.rank + 1
1992+
dst_rank = n % self.world_size
1993+
fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
1994+
try:
1995+
fut.wait()
1996+
except ValueError as e:
1997+
msg = str(e)
1998+
# Ensure newlines are unescaped to provide a better repr of error.
1999+
self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape"))
2000+
else:
2001+
self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
2002+
19852003
@dist_init
19862004
def test_nested_rpc(self):
19872005
n = self.rank + 1

0 commit comments

Comments
 (0)