Skip to content

Commit 818428f

Browse files
authored
Merge pull request apache#15863 from [BEAM-13184] Autosharding for JdbcIO.write* transforms
* Supporting autosharding on JdbcIO.write transforms
* Making autosharding optional
* Adding validation
* Adding an integration test
* Reducing code duplication
* Adding a maximum bundle size to avoid overwhelming the memory
1 parent 0c2f5a5 commit 818428f

3 files changed

Lines changed: 160 additions & 16 deletions

File tree

sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,27 @@
6666
import org.apache.beam.sdk.transforms.DoFn;
6767
import org.apache.beam.sdk.transforms.Filter;
6868
import org.apache.beam.sdk.transforms.GroupByKey;
69+
import org.apache.beam.sdk.transforms.GroupIntoBatches;
6970
import org.apache.beam.sdk.transforms.PTransform;
7071
import org.apache.beam.sdk.transforms.ParDo;
7172
import org.apache.beam.sdk.transforms.Reshuffle;
7273
import org.apache.beam.sdk.transforms.SerializableFunction;
7374
import org.apache.beam.sdk.transforms.SerializableFunctions;
75+
import org.apache.beam.sdk.transforms.Values;
7476
import org.apache.beam.sdk.transforms.View;
7577
import org.apache.beam.sdk.transforms.Wait;
78+
import org.apache.beam.sdk.transforms.WithKeys;
7679
import org.apache.beam.sdk.transforms.display.DisplayData;
7780
import org.apache.beam.sdk.transforms.display.HasDisplayData;
81+
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
7882
import org.apache.beam.sdk.util.BackOff;
7983
import org.apache.beam.sdk.util.BackOffUtils;
8084
import org.apache.beam.sdk.util.FluentBackoff;
8185
import org.apache.beam.sdk.util.Sleeper;
8286
import org.apache.beam.sdk.values.KV;
8387
import org.apache.beam.sdk.values.PBegin;
8488
import org.apache.beam.sdk.values.PCollection;
89+
import org.apache.beam.sdk.values.PCollection.IsBounded;
8590
import org.apache.beam.sdk.values.PCollectionView;
8691
import org.apache.beam.sdk.values.PDone;
8792
import org.apache.beam.sdk.values.Row;
@@ -96,6 +101,7 @@
96101
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
97102
import org.checkerframework.checker.nullness.qual.Nullable;
98103
import org.joda.time.Duration;
104+
import org.joda.time.Instant;
99105
import org.slf4j.Logger;
100106
import org.slf4j.LoggerFactory;
101107

@@ -1318,6 +1324,11 @@ public static class Write<T> extends PTransform<PCollection<T>, PDone> {
13181324
this.inner = inner;
13191325
}
13201326

1327+
/** See {@link WriteVoid#withAutoSharding()}. */
1328+
public Write<T> withAutoSharding() {
1329+
return new Write<>(inner.withAutoSharding());
1330+
}
1331+
13211332
/** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */
13221333
public Write<T> withDataSourceConfiguration(DataSourceConfiguration config) {
13231334
return new Write<>(inner.withDataSourceConfiguration(config));
@@ -1393,6 +1404,7 @@ public <V extends JdbcWriteResult> WriteWithResults<T, V> withWriteResults(
13931404
.setPreparedStatementSetter(inner.getPreparedStatementSetter())
13941405
.setStatement(inner.getStatement())
13951406
.setTable(inner.getTable())
1407+
.setAutoSharding(inner.getAutoSharding())
13961408
.build();
13971409
}
13981410

@@ -1408,6 +1420,50 @@ public PDone expand(PCollection<T> input) {
14081420
}
14091421
}
14101422

1423+
/* The maximum number of elements that will be included in a batch. */
1424+
private static final Integer MAX_BUNDLE_SIZE = 5000;
1425+
1426+
static <T> PCollection<Iterable<T>> batchElements(
1427+
PCollection<T> input, Boolean withAutoSharding) {
1428+
PCollection<Iterable<T>> iterables;
1429+
if (input.isBounded() == IsBounded.UNBOUNDED && withAutoSharding != null && withAutoSharding) {
1430+
iterables =
1431+
input
1432+
.apply(WithKeys.<String, T>of(""))
1433+
.apply(
1434+
GroupIntoBatches.<String, T>ofSize(DEFAULT_BATCH_SIZE)
1435+
.withMaxBufferingDuration(Duration.millis(200))
1436+
.withShardedKey())
1437+
.apply(Values.create());
1438+
} else {
1439+
iterables =
1440+
input.apply(
1441+
ParDo.of(
1442+
new DoFn<T, Iterable<T>>() {
1443+
List<T> outputList;
1444+
1445+
@ProcessElement
1446+
public void process(ProcessContext c) {
1447+
if (outputList == null) {
1448+
outputList = new ArrayList<>();
1449+
}
1450+
outputList.add(c.element());
1451+
if (outputList.size() > MAX_BUNDLE_SIZE) {
1452+
c.output(outputList);
1453+
outputList = null;
1454+
}
1455+
}
1456+
1457+
@FinishBundle
1458+
public void finish(FinishBundleContext c) {
1459+
c.output(outputList, Instant.now(), GlobalWindow.INSTANCE);
1460+
outputList = null;
1461+
}
1462+
}));
1463+
}
1464+
return iterables;
1465+
}
1466+
14111467
/** Interface implemented by functions that sets prepared statement data. */
14121468
@FunctionalInterface
14131469
interface PreparedStatementSetCaller extends Serializable {
@@ -1430,6 +1486,8 @@ void set(
14301486
@AutoValue
14311487
public abstract static class WriteWithResults<T, V extends JdbcWriteResult>
14321488
extends PTransform<PCollection<T>, PCollection<V>> {
1489+
/** Returns the autosharding flag, or {@code null} when it has not been set. */
abstract @Nullable Boolean getAutoSharding();
1490+
14331491
abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();
14341492

14351493
abstract @Nullable ValueProvider<String> getStatement();
@@ -1451,6 +1509,8 @@ abstract static class Builder<T, V extends JdbcWriteResult> {
14511509
abstract Builder<T, V> setDataSourceProviderFn(
14521510
SerializableFunction<Void, DataSource> dataSourceProviderFn);
14531511

1512+
/** Sets the autosharding flag that is passed to {@code JdbcIO.batchElements} on expansion. */
abstract Builder<T, V> setAutoSharding(Boolean autoSharding);
1513+
14541514
abstract Builder<T, V> setStatement(ValueProvider<String> statement);
14551515

14561516
abstract Builder<T, V> setPreparedStatementSetter(PreparedStatementSetter<T> setter);
@@ -1487,6 +1547,11 @@ public WriteWithResults<T, V> withPreparedStatementSetter(PreparedStatementSette
14871547
return toBuilder().setPreparedStatementSetter(setter).build();
14881548
}
14891549

1550+
/** If true, enables using a dynamically determined number of shards to write. */
1551+
public WriteWithResults<T, V> withAutoSharding() {
1552+
return toBuilder().setAutoSharding(true).build();
1553+
}
1554+
14901555
/**
14911556
* When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it
14921557
* will retry the statements. If {@link RetryStrategy#apply(SQLException)} returns {@code true},
@@ -1549,8 +1614,14 @@ public PCollection<V> expand(PCollection<T> input) {
15491614
checkArgument(
15501615
(getDataSourceProviderFn() != null),
15511616
"withDataSourceConfiguration() or withDataSourceProviderFn() is required");
1617+
checkArgument(
1618+
getAutoSharding() == null
1619+
|| (getAutoSharding() && input.isBounded() != IsBounded.UNBOUNDED),
1620+
"Autosharding is only supported for streaming pipelines.");
1621+
;
15521622

1553-
return input.apply(
1623+
PCollection<Iterable<T>> iterables = JdbcIO.<T>batchElements(input, getAutoSharding());
1624+
return iterables.apply(
15541625
ParDo.of(
15551626
new WriteFn<T, V>(
15561627
WriteFnSpec.builder()
@@ -1573,6 +1644,8 @@ public PCollection<V> expand(PCollection<T> input) {
15731644
@AutoValue
15741645
public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCollection<Void>> {
15751646

1647+
/** Returns the autosharding flag, or {@code null} when it has not been set. */
abstract @Nullable Boolean getAutoSharding();
1648+
15761649
abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();
15771650

15781651
abstract @Nullable ValueProvider<String> getStatement();
@@ -1591,6 +1664,8 @@ public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCo
15911664

15921665
@AutoValue.Builder
15931666
abstract static class Builder<T> {
1667+
/** Sets the autosharding flag that is passed to {@code JdbcIO.batchElements} on expansion. */
abstract Builder<T> setAutoSharding(Boolean autoSharding);
1668+
15941669
abstract Builder<T> setDataSourceProviderFn(
15951670
SerializableFunction<Void, DataSource> dataSourceProviderFn);
15961671

@@ -1609,6 +1684,11 @@ abstract Builder<T> setDataSourceProviderFn(
16091684
abstract WriteVoid<T> build();
16101685
}
16111686

1687+
/** If true, enables using a dynamically determined number of shards to write. */
1688+
public WriteVoid<T> withAutoSharding() {
1689+
return toBuilder().setAutoSharding(true).build();
1690+
}
1691+
16121692
public WriteVoid<T> withDataSourceConfiguration(DataSourceConfiguration config) {
16131693
return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
16141694
}
@@ -1708,7 +1788,10 @@ public PCollection<Void> expand(PCollection<T> input) {
17081788
checkArgument(
17091789
spec.getPreparedStatementSetter() != null, "withPreparedStatementSetter() is required");
17101790
}
1711-
return input
1791+
1792+
PCollection<Iterable<T>> iterables = JdbcIO.<T>batchElements(input, getAutoSharding());
1793+
1794+
return iterables
17121795
.apply(
17131796
ParDo.of(
17141797
new WriteFn<T, Void>(
@@ -1955,7 +2038,7 @@ public void populateDisplayData(DisplayData.Builder builder) {
19552038
* @param <T>
19562039
* @param <V>
19572040
*/
1958-
static class WriteFn<T, V> extends DoFn<T, V> {
2041+
static class WriteFn<T, V> extends DoFn<Iterable<T>, V> {
19592042

19602043
@AutoValue
19612044
abstract static class WriteFnSpec<T, V> implements Serializable, HasDisplayData {
@@ -2045,7 +2128,6 @@ abstract static class Builder<T, V> {
20452128
private Connection connection;
20462129
private PreparedStatement preparedStatement;
20472130
private static FluentBackoff retryBackOff;
2048-
private final List<T> records = new ArrayList<>();
20492131

20502132
public WriteFn(WriteFnSpec<T, V> spec) {
20512133
this.spec = spec;
@@ -2085,17 +2167,12 @@ private Connection getConnection() throws SQLException {
20852167

20862168
@ProcessElement
20872169
public void processElement(ProcessContext context) throws Exception {
2088-
T record = context.element();
2089-
records.add(record);
2090-
if (records.size() >= spec.getBatchSize()) {
2091-
executeBatch(context);
2092-
}
2170+
executeBatch(context, context.element());
20932171
}
20942172

20952173
@FinishBundle
20962174
public void finishBundle() throws Exception {
20972175
// Batches are executed in processElement, so only the statement and connection are released here.
2098-
executeBatch(null);
20992176
cleanUpStatementAndConnection();
21002177
}
21012178

@@ -2124,11 +2201,8 @@ private void cleanUpStatementAndConnection() throws Exception {
21242201
}
21252202
}
21262203

2127-
private void executeBatch(ProcessContext context)
2204+
private void executeBatch(ProcessContext context, Iterable<T> records)
21282205
throws SQLException, IOException, InterruptedException {
2129-
if (records.isEmpty()) {
2130-
return;
2131-
}
21322206
Long startTimeNs = System.nanoTime();
21332207
Sleeper sleeper = Sleeper.DEFAULT;
21342208
BackOff backoff = retryBackOff.backoff();
@@ -2137,16 +2211,18 @@ private void executeBatch(ProcessContext context)
21372211
getConnection().prepareStatement(spec.getStatement().get())) {
21382212
try {
21392213
// add each record in the statement batch
2214+
int recordsInBatch = 0;
21402215
for (T record : records) {
21412216
processRecord(record, preparedStatement, context);
2217+
recordsInBatch += 1;
21422218
}
21432219
if (!spec.getReturnResults()) {
21442220
// execute the batch
21452221
preparedStatement.executeBatch();
21462222
// commit the changes
21472223
getConnection().commit();
21482224
}
2149-
RECORDS_PER_BATCH.update(records.size());
2225+
RECORDS_PER_BATCH.update(recordsInBatch);
21502226
MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs));
21512227
break;
21522228
} catch (SQLException exception) {
@@ -2164,7 +2240,6 @@ private void executeBatch(ProcessContext context)
21642240
}
21652241
}
21662242
}
2167-
records.clear();
21682243
}
21692244

21702245
private void processRecord(T record, PreparedStatement preparedStatement, ProcessContext c) {

sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@
3232
import java.util.UUID;
3333
import java.util.function.Function;
3434
import org.apache.beam.sdk.PipelineResult;
35+
import org.apache.beam.sdk.coders.KvCoder;
36+
import org.apache.beam.sdk.coders.StringUtf8Coder;
37+
import org.apache.beam.sdk.coders.VarIntCoder;
3538
import org.apache.beam.sdk.io.GenerateSequence;
3639
import org.apache.beam.sdk.io.common.DatabaseTestHelper;
3740
import org.apache.beam.sdk.io.common.HashingFn;
3841
import org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions;
3942
import org.apache.beam.sdk.io.common.TestRow;
4043
import org.apache.beam.sdk.testing.PAssert;
4144
import org.apache.beam.sdk.testing.TestPipeline;
45+
import org.apache.beam.sdk.testing.TestStream;
4246
import org.apache.beam.sdk.testutils.NamedTestResult;
4347
import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
4448
import org.apache.beam.sdk.testutils.metrics.MetricsReader;
@@ -51,6 +55,7 @@
5155
import org.apache.beam.sdk.transforms.Top;
5256
import org.apache.beam.sdk.values.KV;
5357
import org.apache.beam.sdk.values.PCollection;
58+
import org.joda.time.Instant;
5459
import org.junit.AfterClass;
5560
import org.junit.BeforeClass;
5661
import org.junit.Rule;
@@ -254,6 +259,40 @@ private PipelineResult runRead() {
254259
return pipelineRead.run();
255260
}
256261

262+
@Test
263+
public void testWriteWithAutosharding() throws Exception {
264+
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
265+
DatabaseTestHelper.createTable(dataSource, firstTableName);
266+
try {
267+
List<KV<Integer, String>> data = getTestDataToWrite(EXPECTED_ROW_COUNT);
268+
TestStream.Builder<KV<Integer, String>> ts =
269+
TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
270+
.advanceWatermarkTo(Instant.now());
271+
for (KV<Integer, String> elm : data) {
272+
ts.addElements(elm);
273+
}
274+
275+
PCollection<KV<Integer, String>> dataCollection =
276+
pipelineWrite.apply(ts.advanceWatermarkToInfinity());
277+
dataCollection.apply(
278+
JdbcIO.<KV<Integer, String>>write()
279+
.withDataSourceProviderFn(voidInput -> dataSource)
280+
.withStatement(String.format("insert into %s values(?, ?) returning *", tableName))
281+
.withAutoSharding()
282+
.withPreparedStatementSetter(
283+
(element, statement) -> {
284+
statement.setInt(1, element.getKey());
285+
statement.setString(2, element.getValue());
286+
}));
287+
288+
pipelineWrite.run().waitUntilFinish();
289+
290+
runRead();
291+
} finally {
292+
DatabaseTestHelper.deleteTable(dataSource, firstTableName);
293+
}
294+
}
295+
257296
@Test
258297
public void testWriteWithWriteResults() throws Exception {
259298
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");

sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.apache.beam.sdk.testing.ExpectedLogs;
7474
import org.apache.beam.sdk.testing.PAssert;
7575
import org.apache.beam.sdk.testing.TestPipeline;
76+
import org.apache.beam.sdk.testing.TestStream;
7677
import org.apache.beam.sdk.transforms.Count;
7778
import org.apache.beam.sdk.transforms.Create;
7879
import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -87,6 +88,7 @@
8788
import org.hamcrest.TypeSafeMatcher;
8889
import org.joda.time.DateTime;
8990
import org.joda.time.Duration;
91+
import org.joda.time.Instant;
9092
import org.joda.time.LocalDate;
9193
import org.joda.time.chrono.ISOChronology;
9294
import org.junit.BeforeClass;
@@ -526,6 +528,31 @@ public void testWrite() throws Exception {
526528
}
527529
}
528530

531+
@Test
532+
public void testWriteWithAutosharding() throws Exception {
533+
String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
534+
DatabaseTestHelper.createTable(DATA_SOURCE, tableName);
535+
TestStream.Builder<KV<Integer, String>> ts =
536+
TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
537+
.advanceWatermarkTo(Instant.now());
538+
539+
try {
540+
List<KV<Integer, String>> data = getDataToWrite(EXPECTED_ROW_COUNT);
541+
for (KV<Integer, String> elm : data) {
542+
ts = ts.addElements(elm);
543+
}
544+
pipeline
545+
.apply(ts.advanceWatermarkToInfinity())
546+
.apply(getJdbcWrite(tableName).withAutoSharding());
547+
548+
pipeline.run().waitUntilFinish();
549+
550+
assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT);
551+
} finally {
552+
DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
553+
}
554+
}
555+
529556
@Test
530557
public void testWriteWithWriteResults() throws Exception {
531558
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
@@ -546,6 +573,9 @@ public void testWriteWithWriteResults() throws Exception {
546573
}));
547574
resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER);
548575

576+
PAssert.thatSingleton(resultSetCollection.apply(Count.globally()))
577+
.isEqualTo((long) EXPECTED_ROW_COUNT);
578+
549579
List<JdbcTestHelper.TestDto> expectedResult = new ArrayList<>();
550580
for (int i = 0; i < EXPECTED_ROW_COUNT; i++) {
551581
expectedResult.add(new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT));

0 commit comments

Comments
 (0)