Skip to content

Commit 093d841

Browse files
Add timeout support when waiting on servables to load.
PiperOrigin-RevId: 635590139
1 parent 31ec013 commit 093d841

File tree

6 files changed

+35
-8
lines changed

6 files changed

+35
-8
lines changed

tensorflow_serving/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ cc_library(
417417
":servable_id",
418418
":servable_state",
419419
"//tensorflow_serving/util:event_bus",
420+
"@com_google_absl//absl/time",
420421
"@com_google_absl//absl/types:optional",
421422
"@org_tensorflow//tensorflow/core:lib",
422423
],

tensorflow_serving/core/servable_state_monitor.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <utility>
2020
#include <vector>
2121

22+
#include "absl/time/time.h"
2223
#include "tensorflow/core/lib/core/notification.h"
2324
#include "tensorflow/core/lib/gtl/cleanup.h"
2425
#include "tensorflow_serving/core/servable_state.h"
@@ -234,11 +235,11 @@ void ServableStateMonitor::Notify(const NotifyFn& notify_fn) {
234235
notify_fns_.push_back(notify_fn);
235236
}
236237

237-
bool ServableStateMonitor::WaitUntilServablesReachState(
238+
bool ServableStateMonitor::WaitUntilServablesReachStateWithTimeout(
238239
const std::vector<ServableRequest>& servables,
239-
const ServableState::ManagerState goal_state,
240+
const ServableState::ManagerState goal_state, absl::Duration timeout,
240241
std::map<ServableId, ServableState::ManagerState>* const states_reached) {
241-
bool reached_goal_state;
242+
bool reached_goal_state = false;
242243
Notification notified;
243244
NotifyWhenServablesReachState(
244245
servables, goal_state,
@@ -251,10 +252,19 @@ bool ServableStateMonitor::WaitUntilServablesReachState(
251252
reached_goal_state = incoming_reached_goal_state;
252253
notified.Notify();
253254
});
254-
notified.WaitForNotification();
255+
notified.WaitForNotificationWithTimeout(timeout);
255256
return reached_goal_state;
256257
}
257258

259+
bool ServableStateMonitor::WaitUntilServablesReachState(
260+
const std::vector<ServableRequest>& servables,
261+
const ServableState::ManagerState goal_state,
262+
std::map<ServableId, ServableState::ManagerState>* const states_reached) {
263+
return WaitUntilServablesReachStateWithTimeout(
264+
servables, goal_state,
265+
/*timeout=*/absl::InfiniteDuration(), states_reached);
266+
}
267+
258268
void ServableStateMonitor::PreHandleEvent(
259269
const EventBus<ServableState>::EventAndTime& state_and_time) {}
260270

tensorflow_serving/core/servable_state_monitor.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include <functional>
2121
#include <map>
2222

23+
#include "absl/time/time.h"
2324
#include "absl/types/optional.h"
2425
#include "tensorflow/core/platform/env.h"
2526
#include "tensorflow/core/platform/macros.h"
@@ -156,11 +157,19 @@ class ServableStateMonitor {
156157
///
157158
/// To understand the return value and the return parameter 'states_reached',
158159
/// please read the documentation on NotifyWhenServablesReachState(...).
160+
/// WaitUntilServablesReachStateWithTimeout and WaitUntilServablesReachState
161+
/// perform the same function, but the former has a timeout while the latter
162+
/// waits indefinitely.
163+
bool WaitUntilServablesReachStateWithTimeout(
164+
const std::vector<ServableRequest>& servables,
165+
ServableState::ManagerState goal_state, absl::Duration timeout,
166+
std::map<ServableId, ServableState::ManagerState>* states_reached =
167+
nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
159168
bool WaitUntilServablesReachState(
160169
const std::vector<ServableRequest>& servables,
161170
ServableState::ManagerState goal_state,
162171
std::map<ServableId, ServableState::ManagerState>* states_reached =
163-
nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
172+
nullptr) TF_MUST_USE_RESULT;
164173

165174
// Subscribes to all servable state changes hitting this monitor. This is
166175
// called after the monitor updates its own state based on the event.

tensorflow_serving/model_servers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ cc_library(
9696
"//tensorflow_serving/util:event_bus",
9797
"//tensorflow_serving/util:unique_ptr_with_deps",
9898
"@com_google_absl//absl/base:core_headers",
99+
"@com_google_absl//absl/time",
99100
"@com_google_absl//absl/types:optional",
100101
"@com_google_protobuf//:cc_wkt_protos",
101102
"@org_tensorflow//tensorflow/core:lib",

tensorflow_serving/model_servers/server_core.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "tensorflow/core/platform/logging.h"
3131
#include "tensorflow_serving/config/file_system_storage_path_source.pb.h"
3232
#include "tensorflow_serving/core/load_servables_fast.h"
33+
#include "tensorflow_serving/core/servable_state_monitor.h"
3334
#include "tensorflow_serving/model_servers/model_platform_types.h"
3435
#include "tensorflow_serving/resources/resource_values.h"
3536
#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.h"
@@ -296,9 +297,10 @@ Status ServerCore::WaitUntilModelsAvailable(const std::set<string>& models,
296297
awaited_servables.push_back(ServableRequest::Latest(model));
297298
}
298299
std::map<ServableId, ServableState::ManagerState> states_reached;
299-
const bool all_models_available = monitor->WaitUntilServablesReachState(
300-
awaited_servables, ServableState::ManagerState::kAvailable,
301-
&states_reached);
300+
const bool all_models_available =
301+
monitor->WaitUntilServablesReachStateWithTimeout(
302+
awaited_servables, ServableState::ManagerState::kAvailable,
303+
options_.servable_state_waiter_timeout, &states_reached);
302304
if (!all_models_available) {
303305
const int num_unavailable_models = std::count_if(
304306
states_reached.begin(), states_reached.end(),
@@ -367,6 +369,7 @@ Status ServerCore::AddModelsViaModelConfigList() {
367369
} else {
368370
// Create a fresh servable state monitor, to avoid getting confused if we're
369371
// re-loading a model-version that has previously been unloaded.
372+
370373
ServableStateMonitor fresh_servable_state_monitor(
371374
servable_event_bus_.get());
372375

tensorflow_serving/model_servers/server_core.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424

2525
#include "google/protobuf/any.pb.h"
2626
#include "absl/base/macros.h"
27+
#include "absl/time/time.h"
2728
#include "absl/types/optional.h"
2829
#include "tensorflow/core/lib/core/status.h"
2930
#include "tensorflow/core/platform/cpu_info.h"
@@ -207,6 +208,8 @@ class ServerCore : public Manager {
207208
// If true, propagate current context to children threads (periodic
208209
// functions) in AspiredVersionsManager.
209210
bool with_current_context = false;
211+
212+
absl::Duration servable_state_waiter_timeout = absl::InfiniteDuration();
210213
};
211214

212215
virtual ~ServerCore() = default;

0 commit comments

Comments
 (0)