Skip to content

Commit 9ba72fa

Browse files
Add option to stop retrying on permanent loading errors.
PiperOrigin-RevId: 657375849
1 parent 6e05a38 commit 9ba72fa

16 files changed

+217
-70
lines changed

tensorflow_serving/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,8 @@ cc_library(
447447
":loader",
448448
":servable_id",
449449
"//tensorflow_serving/util:retrier",
450+
"@com_google_absl//absl/log",
451+
"@com_google_absl//absl/status",
450452
"@com_google_absl//absl/types:optional",
451453
"@org_tensorflow//tensorflow/core:lib",
452454
],

tensorflow_serving/core/aspired_versions_manager.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ Status AspiredVersionsManager::Create(
166166
basic_manager_options.env = options.env;
167167
basic_manager_options.servable_event_bus = options.servable_event_bus;
168168
basic_manager_options.pre_load_hook = std::move(options.pre_load_hook);
169+
if (options.should_retry_model_load) {
170+
basic_manager_options.should_retry_model_load =
171+
std::move(options.should_retry_model_load);
172+
}
169173
std::unique_ptr<BasicManager> basic_manager;
170174
TF_RETURN_IF_ERROR(
171175
BasicManager::Create(std::move(basic_manager_options), &basic_manager));

tensorflow_serving/core/aspired_versions_manager.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ class AspiredVersionsManager : public Manager,
127127
/// Default: 1 minute.
128128
int64_t load_retry_interval_micros = 1LL * 60 * 1000 * 1000;
129129

130+
// Defines how we want to retry when model loading fails.
131+
std::function<bool(absl::Status)> should_retry_model_load;
132+
130133
// If true, and there are not multiple load threads, filesystem caches will
131134
// be flushed after each servable is loaded. (Cache flush is skipped when
132135
// multiple load threads are active, in order to avoid setting back a

tensorflow_serving/core/basic_manager.cc

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@ limitations under the License.
2424
#include <utility>
2525
#include <vector>
2626

27+
#include "absl/status/status.h"
2728
#include "tensorflow/core/lib/core/errors.h"
2829
#include "tensorflow/core/lib/strings/strcat.h"
2930
#include "tensorflow/core/platform/logging.h"
3031
#include "tensorflow/core/platform/macros.h"
3132
#include "tensorflow_serving/core/servable_handle.h"
33+
#include "tensorflow_serving/core/servable_state.h"
3234
#include "tensorflow_serving/core/source.h"
35+
#include "tensorflow_serving/resources/resource_tracker.h"
36+
#include "tensorflow_serving/util/event_bus.h"
3337
#include "tensorflow_serving/util/hash.h"
3438
#include "tensorflow_serving/util/inline_executor.h"
3539
#include "tensorflow_serving/util/retrier.h"
@@ -225,21 +229,23 @@ Status BasicManager::Create(Options options,
225229
std::unique_ptr<BasicManager>* manager) {
226230
manager->reset(new BasicManager(
227231
options.env, options.num_load_threads, options.num_unload_threads,
228-
options.max_num_load_retries, options.load_retry_interval_micros,
229-
options.flush_filesystem_caches, std::move(options.resource_tracker),
230-
options.servable_event_bus, std::move(options.pre_load_hook)));
232+
options.max_num_load_retries, std::move(options.should_retry_model_load),
233+
options.load_retry_interval_micros, options.flush_filesystem_caches,
234+
std::move(options.resource_tracker), options.servable_event_bus,
235+
std::move(options.pre_load_hook)));
231236
return OkStatus();
232237
}
233238

234-
BasicManager::BasicManager(Env* const env, const uint32 num_load_threads,
235-
const uint32 num_unload_threads,
236-
uint32 max_num_load_retries,
237-
int64_t load_retry_interval_micros,
238-
bool flush_filesystem_caches,
239-
std::unique_ptr<ResourceTracker> resource_tracker,
240-
EventBus<ServableState>* servable_event_bus,
241-
std::function<void(const ServableId&)> pre_load_hook)
239+
BasicManager::BasicManager(
240+
Env* const env, const uint32 num_load_threads,
241+
const uint32 num_unload_threads, uint32 max_num_load_retries,
242+
std::function<bool(absl::Status)> should_retry_model_load,
243+
int64_t load_retry_interval_micros, bool flush_filesystem_caches,
244+
std::unique_ptr<ResourceTracker> resource_tracker,
245+
EventBus<ServableState>* servable_event_bus,
246+
std::function<void(const ServableId&)> pre_load_hook)
242247
: servable_event_bus_(servable_event_bus),
248+
should_retry_model_load_(std::move(should_retry_model_load)),
243249
env_(env),
244250
num_load_threads_(num_load_threads),
245251
flush_filesystem_caches_(flush_filesystem_caches),
@@ -357,6 +363,9 @@ Status BasicManager::ManageServableInternal(
357363

358364
std::shared_ptr<LoaderHarness> harness =
359365
harness_creator(servable.id(), std::move(loader));
366+
if (should_retry_model_load_) {
367+
harness->set_should_retry(should_retry_model_load_);
368+
}
360369
if (!servable.status().ok()) {
361370
harness->Error(servable.status());
362371
} else {
@@ -527,7 +536,7 @@ void BasicManager::CancelLoadServableRetry(const ServableId& id) {
527536
if (!status.ok()) {
528537
return;
529538
}
530-
harness->set_cancel_load_retry(true);
539+
harness->set_should_retry([](absl::Status status) { return false; });
531540
}
532541

533542
Status BasicManager::ExecuteUnload(LoaderHarness* harness) {
@@ -748,7 +757,7 @@ Status BasicManager::ReserveResources(LoaderHarness* harness,
748757
return resource_tracker_->ReserveResources(*harness->loader(),
749758
&resources_reserved);
750759
},
751-
[&]() { return harness->cancel_load_retry(); });
760+
[&](absl::Status status) { return harness->should_retry(status); });
752761
if (!reserve_resources_status.ok()) {
753762
return Status(
754763
reserve_resources_status.code(),

tensorflow_serving/core/basic_manager.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ class BasicManager : public Manager {
125125
// If set as 0, we don't use a thread-pool, and UnloadServable() blocks.
126126
uint32 num_unload_threads = 0;
127127

128+
// Defines how we want to retry when model loading fails.
129+
std::function<bool(absl::Status)> should_retry_model_load;
130+
128131
// EventBus to publish servable state changes. This is optional, if unset,
129132
// we don't publish.
130133
EventBus<ServableState>* servable_event_bus = nullptr;
@@ -242,8 +245,9 @@ class BasicManager : public Manager {
242245
/// succeed and the rest will fail with an error status.
243246
void LoadServable(const ServableId& id, DoneCallback done_callback);
244247

245-
/// Cancels retrying the servable load during LoadServable(). Does nothing if
246-
/// the servable isn't managed.
248+
/// Cancels retrying the servable load during LoadServable() by replacing the
249+
/// LoaderHarness::should_retry with a function that always returns false.
250+
/// Does nothing if the servable isn't managed.
247251
///
248252
/// If the retries are cancelled, the servable goes into a state dependent on
249253
/// the last Load() called on it. If the last Load() was successful, it will
@@ -269,8 +273,9 @@ class BasicManager : public Manager {
269273
friend class test_util::BasicManagerTestAccess;
270274

271275
BasicManager(Env* env, uint32 num_load_threads, uint32 num_unload_threads,
272-
uint32 max_num_load_retries, int64_t load_retry_interval_micros,
273-
bool flush_filesystem_caches,
276+
uint32 max_num_load_retries,
277+
std::function<bool(absl::Status)> should_retry_model_load,
278+
int64_t load_retry_interval_micros, bool flush_filesystem_caches,
274279
std::unique_ptr<ResourceTracker> resource_tracker,
275280
EventBus<ServableState>* servable_event_bus,
276281
PreLoadHook pre_load_hook);
@@ -418,6 +423,9 @@ class BasicManager : public Manager {
418423
// if no bus has been configured.
419424
EventBus<ServableState>* servable_event_bus_;
420425

426+
// Defines how we want to retry when model loading fails.
427+
std::function<bool(absl::Status)> should_retry_model_load_;
428+
421429
// Used to protect access to 'managed_map_', 'resource_tracker_' and other
422430
// core state elements.
423431
mutable mutex mu_;

tensorflow_serving/core/basic_manager_test.cc

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ limitations under the License.
2626

2727
#include <gmock/gmock.h>
2828
#include <gtest/gtest.h>
29+
#include "absl/status/status.h"
2930
#include "absl/types/optional.h"
3031
#include "tensorflow/core/lib/core/errors.h"
3132
#include "tensorflow/core/lib/core/status_test_util.h"
3233
#include "tensorflow/core/lib/strings/strcat.h"
3334
#include "tensorflow/core/platform/blocking_counter.h"
35+
#include "tensorflow/core/platform/errors.h"
3436
#include "tensorflow/core/platform/null_file_system.h"
3537
#include "tensorflow/core/protobuf/error_codes.pb.h"
3638
#include "tensorflow_serving/core/servable_state_monitor.h"
@@ -1689,12 +1691,51 @@ TEST(EstimateResourcesRetriedTest, Fails) {
16891691
id, [](const Status& status) { EXPECT_FALSE(status.ok()); });
16901692
WaitUntilServableManagerStateIsOneOf(servable_state_monitor, id,
16911693
{ServableState::ManagerState::kEnd});
1692-
const ServableState available_state = {
1693-
id, ServableState::ManagerState::kEnd,
1694-
errors::Internal("Error on estimate resources.")};
16951694
EXPECT_FALSE(servable_state_monitor.GetState(id)->health.ok());
16961695
}
16971696

1697+
TEST(EstimateResourcesRetriedTest, NonRetriableError) {
1698+
std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1699+
EventBus<ServableState>::CreateEventBus();
1700+
ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1701+
1702+
BasicManager::Options options;
1703+
// Seed the manager with ten resource units.
1704+
options.resource_tracker = CreateSimpleResourceTracker(10);
1705+
options.servable_event_bus = servable_event_bus.get();
1706+
options.num_load_threads = 0;
1707+
options.num_unload_threads = 0;
1708+
options.should_retry_model_load =
1709+
([](absl::Status status) { return !absl::IsInvalidArgument(status); });
1710+
1711+
options.max_num_load_retries = 10;
1712+
options.load_retry_interval_micros = 100000000;
1713+
1714+
std::unique_ptr<BasicManager> basic_manager;
1715+
TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1716+
1717+
const ServableId id = {kServableName, 7};
1718+
test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
1719+
EXPECT_CALL(*loader, LoadWithMetadata(_))
1720+
.WillOnce(Return(errors::InvalidArgument("Non-retriable error.")))
1721+
.WillRepeatedly(Return(absl::OkStatus()));
1722+
TF_ASSERT_OK(basic_manager->ManageServable(
1723+
CreateServableData(id, std::unique_ptr<Loader>(loader))));
1724+
basic_manager->LoadServable(
1725+
id, [](const auto& status) { EXPECT_FALSE(status.ok()); });
1726+
1727+
// Make sure the final state is kEnd.
1728+
WaitUntilServableManagerStateIsOneOf(
1729+
servable_state_monitor, id,
1730+
{ServableState::ManagerState::kEnd,
1731+
ServableState::ManagerState::kAvailable});
1732+
const auto final_state = servable_state_monitor.GetState(id);
1733+
ASSERT_TRUE(final_state.has_value());
1734+
EXPECT_EQ(final_state->manager_state, ServableState::ManagerState::kEnd);
1735+
EXPECT_FALSE(final_state->health.ok());
1736+
EXPECT_EQ(final_state->health.message(), "Non-retriable error.");
1737+
}
1738+
16981739
} // namespace
16991740
} // namespace serving
17001741
} // namespace tensorflow

tensorflow_serving/core/loader_harness.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ limitations under the License.
1515

1616
#include "tensorflow_serving/core/loader_harness.h"
1717

18-
#include <algorithm>
18+
#include <functional>
1919
#include <memory>
2020
#include <utility>
2121

22+
#include "absl/log/log.h"
23+
#include "absl/status/status.h"
2224
#include "tensorflow/core/lib/core/errors.h"
2325
#include "tensorflow/core/lib/strings/strcat.h"
2426
#include "tensorflow/core/platform/env.h"
@@ -33,7 +35,8 @@ LoaderHarness::LoaderHarness(const ServableId& id,
3335
: id_(id),
3436
loader_(std::move(loader)),
3537
additional_state_(nullptr),
36-
options_(options) {
38+
options_(options),
39+
should_retry_([&](absl::Status status) { return true; }) {
3740
VLOG(1) << "Starting to manage servable version " << id_;
3841
}
3942

@@ -80,13 +83,13 @@ Status LoaderHarness::Load() {
8083
strings::StrCat("Loading servable: ", id_.DebugString()),
8184
options_.max_num_load_retries, options_.load_retry_interval_micros,
8285
[&]() { return loader_->LoadWithMetadata({id_}); },
83-
[&]() { return cancel_load_retry(); });
86+
[&](absl::Status status) { return should_retry(status); });
8487

8588
if (status.ok()) {
86-
if (cancel_load_retry()) {
87-
// Servable is going to be unloaded very soon,
88-
// we report a failure here so that we do not accidentally
89-
// report that the servable is available.
89+
if (!should_retry(absl::UnknownError(""))) {
90+
// Using UnknownError to check if the load is cancelled. If so, it means
91+
// Servable is going to be unloaded very soon, we report a failure here so
92+
// that we do not accidentally report that the servable is available.
9093
TF_RETURN_IF_ERROR(UnloadDueToCancelledLoad());
9194
return errors::Cancelled(
9295
strings::StrCat("Loading of servable cancelled"));
@@ -135,14 +138,15 @@ Status LoaderHarness::UnloadDueToCancelledLoad() {
135138
return UnloadInternal(State::kLoading);
136139
}
137140

138-
void LoaderHarness::set_cancel_load_retry(const bool value) {
141+
void LoaderHarness::set_should_retry(
142+
std::function<bool(absl::Status)> should_retry) {
139143
mutex_lock l(mu_);
140-
cancel_load_retry_ = value;
144+
should_retry_ = std::move(should_retry);
141145
}
142146

143-
bool LoaderHarness::cancel_load_retry() {
147+
bool LoaderHarness::should_retry(absl::Status status) {
144148
mutex_lock l(mu_);
145-
return cancel_load_retry_;
149+
return should_retry_(status);
146150
}
147151

148152
Status LoaderHarness::Unload() { return UnloadInternal(State::kQuiesced); }

tensorflow_serving/core/loader_harness.h

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818

1919
#include <memory>
2020

21+
#include "absl/status/status.h"
2122
#include "absl/types/optional.h"
2223
#include "tensorflow/core/lib/core/status.h"
2324
#include "tensorflow/core/platform/macros.h"
@@ -120,7 +121,8 @@ class LoaderHarness final {
120121
: id_(id),
121122
loader_(std::move(loader)),
122123
additional_state_(std::move(additional_state)),
123-
options_(options) {}
124+
options_(options),
125+
should_retry_([&](absl::Status status) { return true; }) {}
124126

125127
/// Legal to destruct iff current state is one of kNew, kDisabled or kError.
126128
/// Check-fails if violated.
@@ -168,14 +170,19 @@ class LoaderHarness final {
168170
/// method can be used to ensure that at most one Load() request can proceed.
169171
Status UnloadRequested() TF_LOCKS_EXCLUDED(mu_);
170172

171-
/// Cancels retrying the load of the servable. This is best-effort, and does
172-
/// not preempt a Load() which is already happening, only subsequent calls.
173+
/// Sets the retry behavior for the servable using a function which accepts
174+
/// the status of the last load attempt and returns a boolean. If the boolean
175+
/// is false, we cancel the next retry. This is best-effort, and does not
176+
/// preempt a Load() which is already happening, only subsequent calls.
173177
///
174178
/// If the retries are cancelled, the servable goes into a state dependent on
175179
/// the last Load() called on it. If the last Load() was successful, it will
176180
/// be in state kReady, else in kError.
177-
void set_cancel_load_retry(bool value) TF_LOCKS_EXCLUDED(mu_);
178-
bool cancel_load_retry() TF_LOCKS_EXCLUDED(mu_);
181+
void set_should_retry(std::function<bool(absl::Status)> should_retry)
182+
TF_LOCKS_EXCLUDED(mu_);
183+
184+
/// Returns true if the servable should be retried.
185+
bool should_retry(absl::Status status) TF_LOCKS_EXCLUDED(mu_);
179186

180187
/// Transitions to kUnloading, delegates to Servable::Unload(), then
181188
/// transitions to kDisabled when Unload() is done.
@@ -241,9 +248,11 @@ class LoaderHarness final {
241248
State state_ TF_GUARDED_BY(mu_) = State::kNew;
242249
// If state_ is kError, this will be non-OK.
243250
Status status_ TF_GUARDED_BY(mu_);
244-
// If set to true, we don't try to retry the load of the servable, if not
245-
// loaded by the first attempt.
246-
bool cancel_load_retry_ TF_GUARDED_BY(mu_) = false;
251+
// The retry policy for the servable. If the function returns false, we cancel
252+
// the next retry. This does not affect the current load action already
253+
// running.
254+
// There is no retry if the last action was successful.
255+
std::function<bool(absl::Status)> should_retry_ TF_GUARDED_BY(mu_);
247256

248257
TF_DISALLOW_COPY_AND_ASSIGN(LoaderHarness);
249258
};

0 commit comments

Comments
 (0)