Skip to content

Commit 900aa4e

Browse files
swolchokfacebook-github-bot
authored andcommitted
[PyTorch] remove convenience RecordFunctionCallback interface (#48620)
Summary: Pull Request resolved: #48620 In preparation for storing bare function pointer (8 bytes) instead of std::function (32 bytes). ghstack-source-id: 118568242 Test Plan: CI Reviewed By: ezyang Differential Revision: D25132183 fbshipit-source-id: 3790cfb5d98479a46cf665b14eb0041a872c13da
1 parent bbeee48 commit 900aa4e

File tree

6 files changed

+45
-54
lines changed

6 files changed

+45
-54
lines changed

android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
9090
#endif
9191

9292
#ifdef TRACE_ENABLED
93-
static bool onFunctionEnter(
93+
static std::unique_ptr<at::ObserverContext> onFunctionEnter(
9494
const at::RecordFunction& fn) {
9595
Trace::beginSection(fn.name().str());
96-
return true;
96+
return nullptr;
9797
}
9898

99-
static void onFunctionExit(const at::RecordFunction&) {
99+
static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
100100
Trace::endSection();
101101
}
102102
#endif

aten/src/ATen/record_function.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -316,17 +316,6 @@ class TORCH_API RecordFunctionCallback {
316316
scopes_.fill(true);
317317
}
318318

319-
// This interface is for observers that do not pass an ObserverContext object
320-
// between start and end callbacks.
321-
explicit RecordFunctionCallback(
322-
std::function<void(const RecordFunction&)> start,
323-
std::function<void(const RecordFunction&)> end =
324-
[](const RecordFunction&) {}):
325-
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
326-
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
327-
scopes_.fill(true);
328-
}
329-
330319
RecordFunctionCallback& needsInputs(bool needs_inputs) {
331320
needs_inputs_ = needs_inputs;
332321
return *this;

binaries/record_function_benchmark.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ const float kLowSamplingProb = 0.0001;
1919

2020
void addTestCallback(
2121
double sampling_prob = 1.0,
22-
std::function<void(const at::RecordFunction&)> fn =
23-
[](const at::RecordFunction&) {}) {
22+
std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn =
23+
[](const at::RecordFunction&) { return nullptr; }) {
2424
auto cb = at::RecordFunctionCallback(
2525
std::move(fn),
26-
[](const at::RecordFunction&) {})
26+
[](const at::RecordFunction&, at::ObserverContext*) {})
2727
.needsInputs(false);
2828
if (sampling_prob < 1.0) {
2929
cb.samplingProb(sampling_prob);
@@ -111,6 +111,7 @@ int main(int argc, char** argv) {
111111
kLowSamplingProb,
112112
[&](const at::RecordFunction& fn) {
113113
++cb_count;
114+
return nullptr;
114115
}
115116
);
116117

test/cpp/jit/test_misc.cpp

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,8 @@ void checkScopeCallbacks() {
739739
std::string(fn.name().str()) == "test_user_scope") {
740740
found_user_scope = true;
741741
}
742-
},
743-
[](const at::RecordFunction&) {}));
742+
return nullptr;
743+
}));
744744

745745
bool bad_scope = false;
746746
auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) {
@@ -752,9 +752,8 @@ void checkScopeCallbacks() {
752752
} else {
753753
bad_scope = true;
754754
}
755-
return true;
756-
},
757-
[](const at::RecordFunction&) {})
755+
return nullptr;
756+
})
758757
.scopes({scope}));
759758
};
760759

@@ -813,8 +812,8 @@ TEST(RecordFunctionTest, Basic) {
813812
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
814813
ts_names.insert(fn.name().str());
815814
}
816-
},
817-
[](const RecordFunction&) {})
815+
return nullptr;
816+
})
818817
.needsInputs(true));
819818

820819
TracedTestInputs eager_inputs, jit_inputs;
@@ -851,9 +850,8 @@ TEST(RecordFunctionTest, Basic) {
851850
if (std::string(fn.name().str()) == "test") {
852851
++sampled_cb_ctr;
853852
}
854-
return true;
855-
},
856-
[](const RecordFunction&) {})
853+
return nullptr;
854+
})
857855
.samplingProb(sampling_prob));
858856
};
859857

@@ -863,9 +861,8 @@ TEST(RecordFunctionTest, Basic) {
863861
if (std::string(fn.name().str()) == "test") {
864862
++non_sampled_cb_ctr;
865863
}
866-
return true;
867-
},
868-
[](const RecordFunction&) {}));
864+
return nullptr;
865+
}));
869866

870867
auto handle = setup_sampled_callback(0.5);
871868

@@ -908,9 +905,8 @@ TEST(RecordFunctionTest, Basic) {
908905
[&fn_names, &mtx](const RecordFunction& fn) {
909906
std::lock_guard<std::mutex> lock(mtx);
910907
fn_names.push_back(fn.name().str());
911-
return true;
912-
},
913-
[](const RecordFunction&) {}));
908+
return nullptr;
909+
}));
914910
{
915911
RecordFunctionGuard g1(false);
916912
{
@@ -934,8 +930,10 @@ TEST(RecordFunctionTest, Basic) {
934930
std::vector<size_t> ids;
935931
auto add_remove_test_add_cb = [&ids](size_t id) {
936932
return addGlobalCallback(RecordFunctionCallback(
937-
[&ids, id](const RecordFunction& fn) { ids.push_back(id); },
938-
[](const RecordFunction&) {}));
933+
[&ids, id](const RecordFunction& fn) {
934+
ids.push_back(id);
935+
return nullptr ;
936+
}));
939937
};
940938

941939
auto h1 = add_remove_test_add_cb(1);
@@ -972,8 +970,7 @@ TEST(RecordFunctionTest, Basic) {
972970

973971
ids.clear();
974972
addGlobalCallback(RecordFunctionCallback(
975-
[&ids](const RecordFunction& fn) { ids.push_back(1); },
976-
[](const RecordFunction&) {}));
973+
[&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; }));
977974

978975
{ RECORD_USER_SCOPE("test"); }
979976

@@ -983,8 +980,7 @@ TEST(RecordFunctionTest, Basic) {
983980

984981
auto th = std::thread([&ids]() {
985982
addThreadLocalCallback(RecordFunctionCallback(
986-
[&ids](const RecordFunction& fn) { ids.push_back(2); },
987-
[](const RecordFunction&) {}));
983+
[&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; }));
988984

989985
{ RECORD_USER_SCOPE("test_thread"); }
990986
});
@@ -1070,8 +1066,7 @@ TEST(RecordFunctionTest, Basic) {
10701066
bool ran = false;
10711067
should_run = false;
10721068
addGlobalCallback(RecordFunctionCallback(
1073-
[&ran](const RecordFunction& fn) { ran = true; },
1074-
[](const RecordFunction&) {})
1069+
[&ran](const RecordFunction& fn) { ran = true; return nullptr; })
10751070
.setShouldRun(shouldRunCallback));
10761071

10771072
{ RECORD_USER_SCOPE("test"); }
@@ -1093,8 +1088,8 @@ TEST(RecordFunctionTest, Basic) {
10931088
auto handle = addThreadLocalCallback(RecordFunctionCallback(
10941089
[&recorded_op](const RecordFunction& fn) {
10951090
recorded_op = fn.name().str();
1096-
},
1097-
[](const RecordFunction&) {}));
1091+
return nullptr;
1092+
}));
10981093
ThreadLocalState state;
10991094
std::thread t_child([state]() {
11001095
ThreadLocalStateGuard g_tls(state);
@@ -1111,16 +1106,20 @@ TEST(RecordFunctionTest, Basic) {
11111106
bool has_ids = false;
11121107
addGlobalCallback(
11131108
RecordFunctionCallback(
1114-
[&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; },
1115-
[](const RecordFunction&) {})
1109+
[&has_ids](const RecordFunction& fn) {
1110+
has_ids = fn.handle() > 0;
1111+
return nullptr;
1112+
})
11161113
.needsIds(true));
11171114
{ RECORD_USER_SCOPE("test"); }
11181115
TORCH_CHECK(has_ids);
11191116
clearCallbacks();
11201117
has_ids = false;
11211118
addGlobalCallback(RecordFunctionCallback(
1122-
[&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; },
1123-
[](const RecordFunction&) {}));
1119+
[&has_ids](const RecordFunction& fn) {
1120+
has_ids = fn.handle() > 0;
1121+
return nullptr;
1122+
}));
11241123
{ RECORD_USER_SCOPE("test"); }
11251124
TORCH_CHECK(!has_ids);
11261125
clearCallbacks();
@@ -1138,6 +1137,7 @@ TEST(RecordFunctionTest, OperatorNameOverload) {
11381137
} else {
11391138
operator_names.insert("No Operator Name");
11401139
}
1140+
return nullptr;
11411141
})
11421142
.scopes({at::RecordScope::FUNCTION}));
11431143
auto t = torch::randn({1, 2, 3}, at::kCPU);
@@ -1209,9 +1209,8 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
12091209
[&done](const RecordFunction&) {
12101210
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
12111211
done = true;
1212-
return true;
1213-
},
1214-
[](const RecordFunction&) {}));
1212+
return nullptr;
1213+
}));
12151214
{
12161215
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
12171216
auto t = torch::randn({1, 2, 3}, at::kCPU);

torch/csrc/autograd/init.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
173173
});
174174
m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
175175
auto cb = at::RecordFunctionCallback(
176-
[](const at::RecordFunction&) {},
177-
[](const at::RecordFunction&) {})
176+
[](const at::RecordFunction&) { return nullptr; },
177+
[](const at::RecordFunction&, at::ObserverContext*) {})
178178
.needsInputs(true)
179179
.samplingProb(sampling_prob);
180180
if (is_global) {

torch/csrc/autograd/profiler_legacy.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ void pushProfilingCallbacksLegacy() {
417417
[](const at::RecordFunction& fn) {
418418
auto state_ptr = getProfilerTLSState();
419419
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
420-
return;
420+
return nullptr;
421421
}
422422
bool record_cuda =
423423
state_ptr->config().state == ProfilerState::CUDA;
@@ -432,8 +432,10 @@ void pushProfilingCallbacksLegacy() {
432432
} else {
433433
state_ptr->pushRange(fn, record_cuda, msg);
434434
}
435+
436+
return nullptr;
435437
},
436-
[](const at::RecordFunction& fn) {
438+
[](const at::RecordFunction& fn, at::ObserverContext*) {
437439
auto state_ptr = getProfilerTLSState();
438440
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
439441
return;

0 commit comments

Comments
 (0)