@@ -24,12 +24,16 @@ limitations under the License.
2424#include < utility>
2525#include < vector>
2626
27+ #include " absl/status/status.h"
2728#include " tensorflow/core/lib/core/errors.h"
2829#include " tensorflow/core/lib/strings/strcat.h"
2930#include " tensorflow/core/platform/logging.h"
3031#include " tensorflow/core/platform/macros.h"
3132#include " tensorflow_serving/core/servable_handle.h"
33+ #include " tensorflow_serving/core/servable_state.h"
3234#include " tensorflow_serving/core/source.h"
35+ #include " tensorflow_serving/resources/resource_tracker.h"
36+ #include " tensorflow_serving/util/event_bus.h"
3337#include " tensorflow_serving/util/hash.h"
3438#include " tensorflow_serving/util/inline_executor.h"
3539#include " tensorflow_serving/util/retrier.h"
@@ -225,21 +229,23 @@ Status BasicManager::Create(Options options,
225229 std::unique_ptr<BasicManager>* manager) {
226230 manager->reset (new BasicManager (
227231 options.env , options.num_load_threads , options.num_unload_threads ,
228- options.max_num_load_retries , options.load_retry_interval_micros ,
229- options.flush_filesystem_caches , std::move (options.resource_tracker ),
230- options.servable_event_bus , std::move (options.pre_load_hook )));
232+ options.max_num_load_retries , std::move (options.should_retry_model_load ),
233+ options.load_retry_interval_micros , options.flush_filesystem_caches ,
234+ std::move (options.resource_tracker ), options.servable_event_bus ,
235+ std::move (options.pre_load_hook )));
231236 return OkStatus ();
232237}
233238
234- BasicManager::BasicManager (Env* const env, const uint32 num_load_threads,
235- const uint32 num_unload_threads ,
236- uint32 max_num_load_retries,
237- int64_t load_retry_interval_micros ,
238- bool flush_filesystem_caches,
239- std::unique_ptr<ResourceTracker> resource_tracker,
240- EventBus<ServableState>* servable_event_bus,
241- std::function<void (const ServableId&)> pre_load_hook)
239+ BasicManager::BasicManager (
240+ Env* const env, const uint32 num_load_threads ,
241+ const uint32 num_unload_threads, uint32 max_num_load_retries,
242+ std::function< bool (absl::Status)> should_retry_model_load ,
243+ int64_t load_retry_interval_micros, bool flush_filesystem_caches,
244+ std::unique_ptr<ResourceTracker> resource_tracker,
245+ EventBus<ServableState>* servable_event_bus,
246+ std::function<void (const ServableId&)> pre_load_hook)
242247 : servable_event_bus_(servable_event_bus),
248+ should_retry_model_load_ (std::move(should_retry_model_load)),
243249 env_(env),
244250 num_load_threads_(num_load_threads),
245251 flush_filesystem_caches_(flush_filesystem_caches),
@@ -357,6 +363,9 @@ Status BasicManager::ManageServableInternal(
357363
358364 std::shared_ptr<LoaderHarness> harness =
359365 harness_creator (servable.id (), std::move (loader));
366+ if (should_retry_model_load_) {
367+ harness->set_should_retry (should_retry_model_load_);
368+ }
360369 if (!servable.status ().ok ()) {
361370 harness->Error (servable.status ());
362371 } else {
@@ -527,7 +536,7 @@ void BasicManager::CancelLoadServableRetry(const ServableId& id) {
527536 if (!status.ok ()) {
528537 return ;
529538 }
530- harness->set_cancel_load_retry ( true );
539+ harness->set_should_retry ([](absl::Status status) { return false ; } );
531540}
532541
533542Status BasicManager::ExecuteUnload (LoaderHarness* harness) {
@@ -748,7 +757,7 @@ Status BasicManager::ReserveResources(LoaderHarness* harness,
748757 return resource_tracker_->ReserveResources (*harness->loader (),
749758 &resources_reserved);
750759 },
751- [&]() { return harness->cancel_load_retry ( ); });
760+ [&](absl::Status status ) { return harness->should_retry (status ); });
752761 if (!reserve_resources_status.ok ()) {
753762 return Status (
754763 reserve_resources_status.code (),
0 commit comments