Skip to content

Commit 6fb9403

Browse files
Avoid GetChildren when using Specific servable versions
For some filesystem providers, like GCS, GetChildren does a lot more work that a simple FileExists call. This change special cases the `SPECIFIC` ServableVersionPolicy and does direct FileExists calls for each one. In the common case of a single version, this can be a single stat() call and avoid an expensive directory listing entirely. This optimization *only* applies when the versions and directories are equivalent to "base_dir/%d". So this fast path now happens before the GetChildren call, but will fall back to the general case of a directory listing when there are folders like: base_dir/ - 00001/ - 2/ and you want the specific version 1. Generally speaking, the support for strtod-ifying the string name is nice, but forces the directory listing. PiperOrigin-RevId: 627111947
1 parent 0fe7da7 commit 6fb9403

File tree

3 files changed

+146
-4
lines changed

3 files changed

+146
-4
lines changed

tensorflow_serving/sources/storage_path/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ cc_library(
7474
"//tensorflow_serving/core:servable_id",
7575
"//tensorflow_serving/core:source",
7676
"//tensorflow_serving/core:storage_path",
77+
"@com_google_absl//absl/status",
78+
"@com_google_absl//absl/strings",
7779
"@com_google_absl//absl/types:variant",
7880
"@org_tensorflow//tensorflow/core:lib",
7981
"@org_tensorflow//tensorflow/core:tensorflow",

tensorflow_serving/sources/storage_path/file_system_storage_path_source.cc

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@ limitations under the License.
2525
#include <utility>
2626
#include <vector>
2727

28+
#include "absl/status/status.h"
29+
#include "absl/strings/str_cat.h"
2830
#include "tensorflow/core/lib/core/errors.h"
2931
#include "tensorflow/core/lib/io/path.h"
3032
#include "tensorflow/core/lib/strings/numbers.h"
33+
#include "tensorflow/core/platform/env.h"
34+
#include "tsl/platform/errors.h"
35+
#include "tsl/platform/macros.h"
3136
#include "tensorflow_serving/core/servable_data.h"
3237
#include "tensorflow_serving/core/servable_id.h"
3338

@@ -159,6 +164,49 @@ bool AspireLatestVersions(
159164
return !children_by_version.empty();
160165
}
161166

167+
// Like `AspireSpecificVersions` but use `FileExists` instead of GetChildren to
168+
// remove unnecessary directory listings. Note that this function has to
169+
// fallback to the general case when there are directories that *parse as* the
170+
// version number via `strtod` but aren't equivalent (e.g., "base_dir/00001"
171+
// rather than "base_dir/1").
172+
//
173+
// Returns true if all the models are loaded.
174+
bool AspireSpecificVersionsFastPath(
175+
const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
176+
std::vector<ServableData<StoragePath>>* versions) {
177+
if (servable.servable_version_policy().specific().versions().empty()) {
178+
// There aren't any requested versions, WARN loudly and explicitly, since
179+
// this is a likely configuration error. Return *true*, since we are done
180+
// with processing this servable.
181+
LOG(WARNING) << "No specific versions requested for servable "
182+
<< servable.servable_name() << ".";
183+
return true;
184+
}
185+
186+
// First ensure that we find *all* the requested versions, so that we can use
187+
// this fast path. If not, we'll call the general AspireSpecificVersions after
188+
// a GetChildren call.
189+
for (const int64_t version :
190+
servable.servable_version_policy().specific().versions()) {
191+
const string version_dir = absl::StrCat(version);
192+
const string child_dir = io::JoinPath(servable.base_path(), version_dir);
193+
194+
const absl::Status status = Env::Default()->FileExists(child_dir);
195+
if (!status.ok()) {
196+
return false;
197+
}
198+
}
199+
200+
// We've found them all. Aspire them one by one.
201+
for (const int64_t version :
202+
servable.servable_version_policy().specific().versions()) {
203+
const string version_dir = absl::StrCat(version);
204+
AspireVersion(servable, version_dir, version, versions);
205+
}
206+
207+
return true;
208+
}
209+
162210
// Aspire versions for a servable configured with the "specific" version policy.
163211
//
164212
// 'children' represents a list of base-path children from the file system.
@@ -213,6 +261,16 @@ Status PollFileSystemForServable(
213261
servable.servable_name(), " with error ", status.ToString());
214262
}
215263

264+
if (servable.servable_version_policy().policy_choice_case() ==
265+
FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific) {
266+
// Special case the specific handler, to avoid GetChildren in the case where
267+
// all of the directories match their version number.
268+
if (AspireSpecificVersionsFastPath(servable, versions)) {
269+
// We found them all, exit early.
270+
return absl::OkStatus();
271+
}
272+
}
273+
216274
// Retrieve a list of base-path children from the file system.
217275
std::vector<string> children;
218276
TF_RETURN_IF_ERROR(
@@ -243,11 +301,10 @@ Status PollFileSystemForServable(
243301
at_least_one_version_found =
244302
AspireAllVersions(servable, children, versions);
245303
break;
246-
case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific: {
304+
case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific:
247305
at_least_one_version_found =
248306
AspireSpecificVersions(servable, children_by_version, versions);
249307
break;
250-
}
251308
default:
252309
return errors::Internal("Unhandled servable version_policy: ",
253310
servable.servable_version_policy().DebugString());

tensorflow_serving/sources/storage_path/file_system_storage_path_source_test.cc

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,89 @@ TEST(FileSystemStoragePathSourceTest, SpecificVersions) {
314314
.PollFileSystemAndInvokeCallback());
315315
}
316316

317+
// This is the same as the `SpecificVersions` test above, but with leading zeros
318+
// on one of the directories to ensure we maintain the `strtod` property of
319+
// directory name => version number.
320+
TEST(FileSystemStoragePathSourceTest, SpecificVersionsLeadingZeros) {
321+
const string base_path =
322+
io::JoinPath(testing::TmpDir(), "SpecificVersionsLeadingZeros");
323+
TF_ASSERT_OK(Env::Default()->CreateDir(base_path));
324+
for (const string& version :
325+
{"non_numerical_child", "42", "33", "30", "21", "00017"}) {
326+
TF_ASSERT_OK(Env::Default()->CreateDir(io::JoinPath(base_path, version)));
327+
}
328+
329+
const FileSystemStoragePathSourceConfig config =
330+
test_util::CreateProto<FileSystemStoragePathSourceConfig>(
331+
strings::Printf("servables: { "
332+
" servable_version_policy { "
333+
" specific { "
334+
" versions: 17"
335+
" versions: 30"
336+
" } "
337+
" } "
338+
" servable_name: 'test_servable_name' "
339+
" base_path: '%s' "
340+
"} "
341+
// Disable the polling thread.
342+
"file_system_poll_wait_seconds: -1 ",
343+
base_path.c_str()));
344+
std::unique_ptr<FileSystemStoragePathSource> source;
345+
TF_ASSERT_OK(FileSystemStoragePathSource::Create(config, &source));
346+
std::unique_ptr<test_util::MockStoragePathTarget> target(
347+
new StrictMock<test_util::MockStoragePathTarget>);
348+
ConnectSourceToTarget(source.get(), target.get());
349+
350+
EXPECT_CALL(
351+
*target,
352+
SetAspiredVersions(
353+
Eq("test_servable_name"),
354+
ElementsAre(
355+
ServableData<StoragePath>({"test_servable_name", 17},
356+
io::JoinPath(base_path, "00017")),
357+
ServableData<StoragePath>({"test_servable_name", 30},
358+
io::JoinPath(base_path, "30")))));
359+
360+
TF_ASSERT_OK(internal::FileSystemStoragePathSourceTestAccess(source.get())
361+
.PollFileSystemAndInvokeCallback());
362+
}
363+
364+
TEST(FileSystemStoragePathSourceTest, SpecificVersionsEmpty) {
365+
const string base_path =
366+
io::JoinPath(testing::TmpDir(), "SpecificVersionsEmpty");
367+
TF_ASSERT_OK(Env::Default()->CreateDir(base_path));
368+
for (const string& version :
369+
{"non_numerical_child", "42", "33", "30", "21", "17"}) {
370+
TF_ASSERT_OK(Env::Default()->CreateDir(io::JoinPath(base_path, version)));
371+
}
372+
373+
const FileSystemStoragePathSourceConfig config =
374+
test_util::CreateProto<FileSystemStoragePathSourceConfig>(
375+
strings::Printf("servables: { "
376+
" servable_version_policy { "
377+
" specific { "
378+
" } "
379+
" } "
380+
" servable_name: 'test_servable_name' "
381+
" base_path: '%s' "
382+
"} "
383+
// Disable the polling thread.
384+
"file_system_poll_wait_seconds: -1 ",
385+
base_path.c_str()));
386+
std::unique_ptr<FileSystemStoragePathSource> source;
387+
TF_ASSERT_OK(FileSystemStoragePathSource::Create(config, &source));
388+
std::unique_ptr<test_util::MockStoragePathTarget> target(
389+
new StrictMock<test_util::MockStoragePathTarget>);
390+
ConnectSourceToTarget(source.get(), target.get());
391+
392+
// The servable has no requested versions, but we still want to call
393+
// SetAspiredVersions with an empty list for consistency.
394+
EXPECT_CALL(*target, SetAspiredVersions(Eq("test_servable_name"), IsEmpty()));
395+
396+
TF_ASSERT_OK(internal::FileSystemStoragePathSourceTestAccess(source.get())
397+
.PollFileSystemAndInvokeCallback());
398+
}
399+
317400
TEST(FileSystemStoragePathSourceTest, DefaultVersionPolicy) {
318401
// Validate that default version policy is to serve the latest servable
319402
// version.
@@ -512,7 +595,7 @@ TEST(FileSystemStoragePathSourceTest, ChangeVersionPolicy) {
512595
const string base_path_prefix =
513596
io::JoinPath(testing::TmpDir(), "ChangeVersionPolicy_");
514597
TF_ASSERT_OK(Env::Default()->CreateDir(base_path_prefix));
515-
for (const string& version : {"1", "2", "3", "5", "8", "13"}) {
598+
for (const string& version : {"1", "02", "3", "5", "8", "13"}) {
516599
TF_ASSERT_OK(
517600
Env::Default()->CreateDir(io::JoinPath(base_path_prefix, version)));
518601
}
@@ -572,7 +655,7 @@ TEST(FileSystemStoragePathSourceTest, ChangeVersionPolicy) {
572655
Eq("test_servable_name"),
573656
ElementsAre(
574657
ServableData<StoragePath>({"test_servable_name", 2},
575-
io::JoinPath(base_path_prefix, "2")),
658+
io::JoinPath(base_path_prefix, "02")),
576659
ServableData<StoragePath>({"test_servable_name", 5},
577660
io::JoinPath(base_path_prefix, "5")))));
578661

0 commit comments

Comments
 (0)