Skip to content

Commit b50c834

Browse files
committed
Fix reliability issues in LogAllSessions. (#22568)
The issue can happen with multiple sessions and when an ETW captureState / rundown is triggered. Resolves a use-after-free issue. Tested with a local unit test that creates/destroys multiple sessions while continually enabling and disabling ETW. Because this currently requires an Admin prompt, the test is not checked in; ORT should not crash.
1 parent b76588b commit b50c834

File tree

2 files changed

+60
-26
lines changed

2 files changed

+60
-26
lines changed

onnxruntime/core/session/inference_session.cc

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ Status GetMinimalBuildOptimizationHandling(
249249
std::atomic<uint32_t> InferenceSession::global_session_id_{1};
250250
std::map<uint32_t, InferenceSession*> InferenceSession::active_sessions_;
251251
#ifdef _WIN32
252-
OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_
252+
std::mutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_
253253
onnxruntime::WindowsTelemetry::EtwInternalCallback InferenceSession::callback_ML_ORT_provider_;
254254
#endif
255255

@@ -371,7 +371,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
371371
session_id_ = global_session_id_.fetch_add(1);
372372

373373
#ifdef _WIN32
374-
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
374+
std::lock_guard<std::mutex> lock(active_sessions_mutex_);
375375
active_sessions_[global_session_id_++] = this;
376376

377377
// Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider
@@ -725,13 +725,9 @@ InferenceSession::~InferenceSession() {
725725

726726
// Unregister the session and ETW callbacks
727727
#ifdef _WIN32
728-
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
729-
if (callback_ML_ORT_provider_ != nullptr) {
730-
WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_);
731-
}
732-
if (callback_ETWSink_provider_ != nullptr) {
733-
logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_);
734-
}
728+
std::lock_guard<std::mutex> lock(active_sessions_mutex_);
729+
WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_);
730+
logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_);
735731
#endif
736732
active_sessions_.erase(global_session_id_);
737733

@@ -749,7 +745,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
749745
return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for exec provider");
750746
}
751747

752-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
748+
std::lock_guard<std::mutex> l(session_mutex_);
753749

754750
if (is_inited_) {
755751
// adding an EP is pointless as the graph as already been partitioned so no nodes will be assigned to
@@ -880,7 +876,7 @@ common::Status InferenceSession::RegisterGraphTransformer(
880876
return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer");
881877
}
882878

883-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
879+
std::lock_guard<std::mutex> l(session_mutex_);
884880

885881
if (is_inited_) {
886882
// adding a transformer now is pointless as the graph as already been transformed
@@ -944,7 +940,7 @@ common::Status InferenceSession::LoadWithLoader(std::function<common::Status(std
944940
tp = session_profiler_.Start();
945941
}
946942
ORT_TRY {
947-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
943+
std::lock_guard<std::mutex> l(session_mutex_);
948944
if (is_model_loaded_) { // already loaded
949945
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
950946
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
@@ -1400,7 +1396,7 @@ Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len
14001396
}
14011397

14021398
Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort_format_model_bytes) {
1403-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
1399+
std::lock_guard<std::mutex> l(session_mutex_);
14041400

14051401
if (is_model_loaded_) { // already loaded
14061402
Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
@@ -1524,7 +1520,7 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
15241520
}
15251521

15261522
bool InferenceSession::IsInitialized() const {
1527-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
1523+
std::lock_guard<std::mutex> l(session_mutex_);
15281524
return is_inited_;
15291525
}
15301526

@@ -1677,7 +1673,7 @@ common::Status InferenceSession::Initialize() {
16771673
bool have_cpu_ep = false;
16781674

16791675
{
1680-
std::lock_guard<onnxruntime::OrtMutex> initial_guard(session_mutex_);
1676+
std::lock_guard<std::mutex> initial_guard(session_mutex_);
16811677

16821678
if (!is_model_loaded_) {
16831679
LOGS(*session_logger_, ERROR) << "Model was not loaded";
@@ -1715,7 +1711,7 @@ common::Status InferenceSession::Initialize() {
17151711
}
17161712

17171713
// re-acquire mutex
1718-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
1714+
std::lock_guard<std::mutex> l(session_mutex_);
17191715

17201716
#if !defined(DISABLE_EXTERNAL_INITIALIZERS) && !defined(ORT_MINIMAL_BUILD)
17211717
if (!session_options_.external_initializers.empty()) {
@@ -2031,9 +2027,11 @@ common::Status InferenceSession::Initialize() {
20312027
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
20322028
}
20332029

2030+
SessionState::PrePackInitializers pre_packed_initializers;
20342031
ORT_RETURN_IF_ERROR_SESSIONID_(
20352032
session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
20362033
// need to keep the initializers if saving the optimized model
2034+
pre_packed_initializers,
20372035
!saving_model,
20382036
saving_ort_format));
20392037

@@ -2069,11 +2067,47 @@ common::Status InferenceSession::Initialize() {
20692067
kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "1024"));
20702068
Graph::OffsetAlignmentInfo align_info;
20712069
align_info.align_offset = true;
2070+
bool save_prepacked_constant_initializers =
2071+
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsSavePrePackedConstantInitializers, "0") == "1" ? true : false;
2072+
Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto;
2073+
if (save_prepacked_constant_initializers) {
2074+
LOGS(*session_logger_, WARNING) << "Serialize prepacked initializers option has been turn on."
2075+
<< "Use this option only when run model inference on PC with CPU."
2076+
<< "Make sure to save and load model in same device as prepack is device specific."
2077+
<< "Note: this feature in only work with ONNX model format."
2078+
<< "Process of use this option is like below:"
2079+
<< "1. Optimize model with external data file with save_prepacked_constant_initializers on:"
2080+
<< " sample: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', ' 1 ')"
2081+
<< " With save_prepacked_constant_initializers option, prepacked initializer will be serialized into data file."
2082+
<< "2. Load optimized model and external data file in same device, no prepack is need."
2083+
<< "3. Run inference with optimized model.";
2084+
2085+
if (fbs::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)) {
2086+
ORT_RETURN_IF_ERROR_SESSIONID_(
2087+
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
2088+
"Unable to serialize prepacked external constant initializer for ORT format model."
2089+
"Please use ONNX format model with save_prepacked_constant_initializers."));
2090+
}
2091+
2092+
// convert pre_packed_initializers to tensorproto format and save to external data file
2093+
for (const auto& name_item_pair : pre_packed_initializers.pre_packed_initializers_to_save) {
2094+
auto initializer_name = name_item_pair.first;
2095+
2096+
for (const auto& kernel_name_initializer_item_pair : name_item_pair.second) {
2097+
auto kernel_name = kernel_name_initializer_item_pair.first;
2098+
auto prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer_name, kernel_name);
2099+
2100+
pre_packed_initializers_tensor_proto[initializer_name][kernel_name] = utils::TensorToTensorProto(kernel_name_initializer_item_pair.second, prepacked_initializer_name);
2101+
}
2102+
}
2103+
}
20722104
ORT_RETURN_IF_ERROR_SESSIONID_(Model::SaveWithExternalInitializers(*model_,
20732105
session_options_.optimized_model_filepath,
20742106
optimized_model_external_initializers_file_name,
20752107
optimized_model_external_initializers_min_size_in_bytes,
2076-
align_info));
2108+
align_info,
2109+
save_prepacked_constant_initializers,
2110+
pre_packed_initializers_tensor_proto));
20772111
}
20782112
}
20792113
}
@@ -2588,7 +2622,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
25882622
std::unique_ptr<logging::Logger> owned_run_logger;
25892623
const auto& run_logger = CreateLoggerForRun(run_options, owned_run_logger);
25902624

2591-
std::optional<std::lock_guard<OrtMutex>> sequential_run_lock;
2625+
std::optional<std::lock_guard<std::mutex>> sequential_run_lock;
25922626
if (is_concurrent_run_supported_ == false) {
25932627
sequential_run_lock.emplace(session_mutex_);
25942628
}
@@ -2841,7 +2875,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, const NameML
28412875

28422876
std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetadata() const {
28432877
{
2844-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
2878+
std::lock_guard<std::mutex> l(session_mutex_);
28452879
if (!is_model_loaded_) {
28462880
LOGS(*session_logger_, ERROR) << "Model was not loaded";
28472881
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
@@ -2853,7 +2887,7 @@ std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetada
28532887

28542888
std::pair<common::Status, const InputDefList*> InferenceSession::GetModelInputs() const {
28552889
{
2856-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
2890+
std::lock_guard<std::mutex> l(session_mutex_);
28572891
if (!is_model_loaded_) {
28582892
LOGS(*session_logger_, ERROR) << "Model was not loaded";
28592893
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
@@ -2866,7 +2900,7 @@ std::pair<common::Status, const InputDefList*> InferenceSession::GetModelInputs(
28662900

28672901
std::pair<common::Status, const InputDefList*> InferenceSession::GetOverridableInitializers() const {
28682902
{
2869-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
2903+
std::lock_guard<std::mutex> l(session_mutex_);
28702904
if (!is_model_loaded_) {
28712905
LOGS(*session_logger_, ERROR) << "Model was not loaded";
28722906
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
@@ -2879,7 +2913,7 @@ std::pair<common::Status, const InputDefList*> InferenceSession::GetOverridableI
28792913

28802914
std::pair<common::Status, const OutputDefList*> InferenceSession::GetModelOutputs() const {
28812915
{
2882-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
2916+
std::lock_guard<std::mutex> l(session_mutex_);
28832917
if (!is_model_loaded_) {
28842918
LOGS(*session_logger_, ERROR) << "Model was not loaded";
28852919
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
@@ -2891,7 +2925,7 @@ std::pair<common::Status, const OutputDefList*> InferenceSession::GetModelOutput
28912925

28922926
common::Status InferenceSession::NewIOBinding(std::unique_ptr<IOBinding>* io_binding) {
28932927
{
2894-
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
2928+
std::lock_guard<std::mutex> l(session_mutex_);
28952929
if (!is_inited_) {
28962930
LOGS(*session_logger_, ERROR) << "Session was not initialized";
28972931
return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized.");
@@ -3275,7 +3309,7 @@ IOBinding* SessionIOBinding::Get() {
32753309
void InferenceSession::LogAllSessions() {
32763310
const Env& env = Env::Default();
32773311

3278-
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
3312+
std::lock_guard<std::mutex> lock(active_sessions_mutex_);
32793313
for (const auto& session_pair : active_sessions_) {
32803314
InferenceSession* session = session_pair.second;
32813315

onnxruntime/core/session/inference_session.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ class InferenceSession {
663663

664664
void InitLogger(logging::LoggingManager* logging_manager);
665665

666-
void TraceSessionOptions(const SessionOptions& session_options, bool captureState);
666+
static void TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger);
667667

668668
[[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
669669
const TensorShape& expected_shape, const char* input_output_moniker) const;
@@ -700,7 +700,7 @@ class InferenceSession {
700700
void ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink);
701701

702702
#ifdef _WIN32
703-
void LogAllSessions();
703+
static void LogAllSessions();
704704
#endif
705705

706706
#if !defined(ORT_MINIMAL_BUILD)

0 commit comments

Comments
 (0)