Skip to content

Commit 84e24ea

Browse files
authored
Merge pull request #15500 from [BEAM-6721] Set numShards dynamically for TextIO.write()
[BEAM-6721] Set numShards dynamically for TextIO.write()
2 parents 46c649a + e096daf commit 84e24ea

2 files changed

Lines changed: 21 additions & 7 deletions

File tree

sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -268,7 +268,6 @@ public static <UserT> TypedWrite<UserT, Void> writeCustomType() {
268268
.setDelimiter(new char[] {'\n'})
269269
.setWritableByteChannelFactory(FileBasedSink.CompressionType.UNCOMPRESSED)
270270
.setWindowedWrites(false)
271-
.setNumShards(0)
272271
.setNoSpilling(false)
273272
.build();
274273
}
@@ -623,7 +622,7 @@ public abstract static class TypedWrite<UserT, DestinationT>
623622
abstract @Nullable String getFooter();
624623

625624
/** Requested number of shards. 0 for automatic. */
626-
abstract int getNumShards();
625+
abstract @Nullable ValueProvider<Integer> getNumShards();
627626

628627
/** The shard template of each file written, combined with prefix and suffix. */
629628
abstract @Nullable String getShardTemplate();
@@ -689,7 +688,8 @@ abstract Builder<UserT, DestinationT> setDestinationFunction(
689688
abstract Builder<UserT, DestinationT> setFormatFunction(
690689
@Nullable SerializableFunction<UserT, String> formatFunction);
691690

692-
abstract Builder<UserT, DestinationT> setNumShards(int numShards);
691+
abstract Builder<UserT, DestinationT> setNumShards(
692+
@Nullable ValueProvider<Integer> numShards);
693693

694694
abstract Builder<UserT, DestinationT> setWindowedWrites(boolean windowedWrites);
695695

@@ -846,6 +846,14 @@ public TypedWrite<UserT, DestinationT> withSuffix(String filenameSuffix) {
846846
*/
847847
public TypedWrite<UserT, DestinationT> withNumShards(int numShards) {
848848
checkArgument(numShards >= 0);
849+
return withNumShards(StaticValueProvider.of(numShards));
850+
}
851+
852+
/**
853+
* Like {@link #withNumShards(int)}. Specifying {@code null} means runner-determined sharding.
854+
*/
855+
public TypedWrite<UserT, DestinationT> withNumShards(
856+
@Nullable ValueProvider<Integer> numShards) {
849857
return toBuilder().setNumShards(numShards).build();
850858
}
851859

@@ -1002,7 +1010,7 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
10021010
getHeader(),
10031011
getFooter(),
10041012
getWritableByteChannelFactory()));
1005-
if (getNumShards() > 0) {
1013+
if (getNumShards() != null) {
10061014
write = write.withNumShards(getNumShards());
10071015
}
10081016
if (getWindowedWrites()) {
@@ -1020,8 +1028,8 @@ public void populateDisplayData(DisplayData.Builder builder) {
10201028

10211029
resolveDynamicDestinations().populateDisplayData(builder);
10221030
builder
1023-
.addIfNotDefault(
1024-
DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards"), 0)
1031+
.addIfNotNull(
1032+
DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards"))
10251033
.addIfNotNull(
10261034
DisplayData.item("tempDirectory", getTempDirectory())
10271035
.withLabel("Directory for temporary files"))
@@ -1139,6 +1147,11 @@ public Write withNumShards(int numShards) {
11391147
return new Write(inner.withNumShards(numShards));
11401148
}
11411149

1150+
/** See {@link TypedWrite#withNumShards(ValueProvider)}. */
1151+
public Write withNumShards(@Nullable ValueProvider<Integer> numShards) {
1152+
return new Write(inner.withNumShards(numShards));
1153+
}
1154+
11421155
/** See {@link TypedWrite#withoutSharding()}. */
11431156
public Write withoutSharding() {
11441157
return new Write(inner.withoutSharding());

sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,8 @@ public void processElement(ProcessContext context) throws Exception {
884884
shardCount = context.sideInput(numShardsView);
885885
} else {
886886
checkNotNull(getNumShardsProvider());
887-
shardCount = getNumShardsProvider().get();
887+
shardCount =
888+
checkNotNull(getNumShardsProvider().get(), "Must have non-null number of shards.");
888889
}
889890
checkArgument(
890891
shardCount > 0,

0 commit comments

Comments
 (0)