@@ -739,8 +739,8 @@ void checkScopeCallbacks() {
739739 std::string (fn.name ().str ()) == " test_user_scope" ) {
740740 found_user_scope = true ;
741741 }
742- },
743- []( const at::RecordFunction&) { }));
742+ return nullptr ;
743+ }));
744744
745745 bool bad_scope = false ;
746746 auto pushScopedCallback = [&](at::RecordScope scope, size_t & cnt) {
@@ -752,9 +752,8 @@ void checkScopeCallbacks() {
752752 } else {
753753 bad_scope = true ;
754754 }
755- return true ;
756- },
757- [](const at::RecordFunction&) {})
755+ return nullptr ;
756+ })
758757 .scopes ({scope}));
759758 };
760759
@@ -813,8 +812,8 @@ TEST(RecordFunctionTest, Basic) {
813812 } else if (fn.scope () == RecordScope::TORCHSCRIPT_FUNCTION) {
814813 ts_names.insert (fn.name ().str ());
815814 }
816- },
817- []( const RecordFunction&) { })
815+ return nullptr ;
816+ })
818817 .needsInputs (true ));
819818
820819 TracedTestInputs eager_inputs, jit_inputs;
@@ -851,9 +850,8 @@ TEST(RecordFunctionTest, Basic) {
851850 if (std::string (fn.name ().str ()) == " test" ) {
852851 ++sampled_cb_ctr;
853852 }
854- return true ;
855- },
856- [](const RecordFunction&) {})
853+ return nullptr ;
854+ })
857855 .samplingProb (sampling_prob));
858856 };
859857
@@ -863,9 +861,8 @@ TEST(RecordFunctionTest, Basic) {
863861 if (std::string (fn.name ().str ()) == " test" ) {
864862 ++non_sampled_cb_ctr;
865863 }
866- return true ;
867- },
868- [](const RecordFunction&) {}));
864+ return nullptr ;
865+ }));
869866
870867 auto handle = setup_sampled_callback (0.5 );
871868
@@ -908,9 +905,8 @@ TEST(RecordFunctionTest, Basic) {
908905 [&fn_names, &mtx](const RecordFunction& fn) {
909906 std::lock_guard<std::mutex> lock (mtx);
910907 fn_names.push_back (fn.name ().str ());
911- return true ;
912- },
913- [](const RecordFunction&) {}));
908+ return nullptr ;
909+ }));
914910 {
915911 RecordFunctionGuard g1 (false );
916912 {
@@ -934,8 +930,10 @@ TEST(RecordFunctionTest, Basic) {
934930 std::vector<size_t > ids;
935931 auto add_remove_test_add_cb = [&ids](size_t id) {
936932 return addGlobalCallback (RecordFunctionCallback (
937- [&ids, id](const RecordFunction& fn) { ids.push_back (id); },
938- [](const RecordFunction&) {}));
933+ [&ids, id](const RecordFunction& fn) {
934+ ids.push_back (id);
935+ return nullptr ;
936+ }));
939937 };
940938
941939 auto h1 = add_remove_test_add_cb (1 );
@@ -972,8 +970,7 @@ TEST(RecordFunctionTest, Basic) {
972970
973971 ids.clear ();
974972 addGlobalCallback (RecordFunctionCallback (
975- [&ids](const RecordFunction& fn) { ids.push_back (1 ); },
976- [](const RecordFunction&) {}));
973+ [&ids](const RecordFunction& fn) { ids.push_back (1 ); return nullptr ; }));
977974
978975 { RECORD_USER_SCOPE (" test" ); }
979976
@@ -983,8 +980,7 @@ TEST(RecordFunctionTest, Basic) {
983980
984981 auto th = std::thread ([&ids]() {
985982 addThreadLocalCallback (RecordFunctionCallback (
986- [&ids](const RecordFunction& fn) { ids.push_back (2 ); },
987- [](const RecordFunction&) {}));
983+ [&ids](const RecordFunction& fn) { ids.push_back (2 ); return nullptr ; }));
988984
989985 { RECORD_USER_SCOPE (" test_thread" ); }
990986 });
@@ -1070,8 +1066,7 @@ TEST(RecordFunctionTest, Basic) {
10701066 bool ran = false ;
10711067 should_run = false ;
10721068 addGlobalCallback (RecordFunctionCallback (
1073- [&ran](const RecordFunction& fn) { ran = true ; },
1074- [](const RecordFunction&) {})
1069+ [&ran](const RecordFunction& fn) { ran = true ; return nullptr ; })
10751070 .setShouldRun (shouldRunCallback));
10761071
10771072 { RECORD_USER_SCOPE (" test" ); }
@@ -1093,8 +1088,8 @@ TEST(RecordFunctionTest, Basic) {
10931088 auto handle = addThreadLocalCallback (RecordFunctionCallback (
10941089 [&recorded_op](const RecordFunction& fn) {
10951090 recorded_op = fn.name ().str ();
1096- },
1097- []( const RecordFunction&) { }));
1091+ return nullptr ;
1092+ }));
10981093 ThreadLocalState state;
10991094 std::thread t_child ([state]() {
11001095 ThreadLocalStateGuard g_tls (state);
@@ -1111,16 +1106,20 @@ TEST(RecordFunctionTest, Basic) {
11111106 bool has_ids = false ;
11121107 addGlobalCallback (
11131108 RecordFunctionCallback (
1114- [&has_ids](const RecordFunction& fn) { has_ids = fn.handle () > 0 ; },
1115- [](const RecordFunction&) {})
1109+ [&has_ids](const RecordFunction& fn) {
1110+ has_ids = fn.handle () > 0 ;
1111+ return nullptr ;
1112+ })
11161113 .needsIds (true ));
11171114 { RECORD_USER_SCOPE (" test" ); }
11181115 TORCH_CHECK (has_ids);
11191116 clearCallbacks ();
11201117 has_ids = false ;
11211118 addGlobalCallback (RecordFunctionCallback (
1122- [&has_ids](const RecordFunction& fn) { has_ids = fn.handle () > 0 ; },
1123- [](const RecordFunction&) {}));
1119+ [&has_ids](const RecordFunction& fn) {
1120+ has_ids = fn.handle () > 0 ;
1121+ return nullptr ;
1122+ }));
11241123 { RECORD_USER_SCOPE (" test" ); }
11251124 TORCH_CHECK (!has_ids);
11261125 clearCallbacks ();
@@ -1138,6 +1137,7 @@ TEST(RecordFunctionTest, OperatorNameOverload) {
11381137 } else {
11391138 operator_names.insert (" No Operator Name" );
11401139 }
1140+ return nullptr ;
11411141 })
11421142 .scopes ({at::RecordScope::FUNCTION}));
11431143 auto t = torch::randn ({1 , 2 , 3 }, at::kCPU );
@@ -1209,9 +1209,8 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
12091209 [&done](const RecordFunction&) {
12101210 checkDebugInfo (c10::DebugInfoKind::TEST_INFO, 42 );
12111211 done = true ;
1212- return true ;
1213- },
1214- [](const RecordFunction&) {}));
1212+ return nullptr ;
1213+ }));
12151214 {
12161215 c10::DebugInfoGuard guard (c10::DebugInfoKind::TEST_INFO, debug_info);
12171216 auto t = torch::randn ({1 , 2 , 3 }, at::kCPU );
0 commit comments