Skip to content

Commit 7518f54

Browse files
gmagogsfmfacebook-github-bot
authored andcommitted
Add flag torch_jit_disable_warning_prints to allow disabling all warnings.warn (#49313)
Summary: Adding a flag torch_jit_disable_warning_prints to optimize interpreter performance by suppressing (potentially large amount) of warnings.warn. This is to work around TorchScript's warning behavior mismatch with Python. Python by default triggers a warning once per location but TorchScript doesn't support it. This causes same warning to trigger and print once per inference run, hurting performance. Pull Request resolved: #49313 Reviewed By: SplitInfinity Differential Revision: D25534274 Pulled By: gmagogsfm fbshipit-source-id: eaeb57a335c3e6c7eb259671645db05d781e80a2
1 parent aff0b68 commit 7518f54

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

torch/csrc/jit/runtime/interpreter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,10 @@ struct CodeImpl {
891891
}
892892

893893
void emitWarn(Node* node) {
894+
if (FLAGS_torch_jit_disable_warning_prints) {
895+
return;
896+
}
897+
894898
emitLoadInputs(node->inputs());
895899
int32_t idx = -1;
896900
if (node->hasAttribute(attr::warn_id)) {

torch/csrc/jit/runtime/interpreter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <torch/csrc/WindowsTorchApiMacro.h>
99
#include <torch/csrc/jit/frontend/source_range.h>
1010

11+
C10_DECLARE_bool(torch_jit_disable_warning_prints);
12+
1113
namespace at {
1214
class Tensor;
1315
CAFFE2_API void launch(std::function<void()> func);

torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ C10_DEFINE_bool(
3535
true,
3636
"If this flag is set to false TorchScript will be using the legacy/original executor");
3737

38+
C10_DEFINE_bool(
39+
torch_jit_disable_warning_prints,
40+
false,
41+
"Disables warning.warn prints in TorchScript graph");
42+
3843
constexpr size_t kDefaultNumProfiledRuns = 1;
3944
constexpr size_t kDefaultBailoutDepth = 20;
4045

0 commit comments

Comments
 (0)