6666import org .apache .beam .sdk .transforms .DoFn ;
6767import org .apache .beam .sdk .transforms .Filter ;
6868import org .apache .beam .sdk .transforms .GroupByKey ;
69+ import org .apache .beam .sdk .transforms .GroupIntoBatches ;
6970import org .apache .beam .sdk .transforms .PTransform ;
7071import org .apache .beam .sdk .transforms .ParDo ;
7172import org .apache .beam .sdk .transforms .Reshuffle ;
7273import org .apache .beam .sdk .transforms .SerializableFunction ;
7374import org .apache .beam .sdk .transforms .SerializableFunctions ;
75+ import org .apache .beam .sdk .transforms .Values ;
7476import org .apache .beam .sdk .transforms .View ;
7577import org .apache .beam .sdk .transforms .Wait ;
78+ import org .apache .beam .sdk .transforms .WithKeys ;
7679import org .apache .beam .sdk .transforms .display .DisplayData ;
7780import org .apache .beam .sdk .transforms .display .HasDisplayData ;
81+ import org .apache .beam .sdk .transforms .windowing .GlobalWindow ;
7882import org .apache .beam .sdk .util .BackOff ;
7983import org .apache .beam .sdk .util .BackOffUtils ;
8084import org .apache .beam .sdk .util .FluentBackoff ;
8185import org .apache .beam .sdk .util .Sleeper ;
8286import org .apache .beam .sdk .values .KV ;
8387import org .apache .beam .sdk .values .PBegin ;
8488import org .apache .beam .sdk .values .PCollection ;
89+ import org .apache .beam .sdk .values .PCollection .IsBounded ;
8590import org .apache .beam .sdk .values .PCollectionView ;
8691import org .apache .beam .sdk .values .PDone ;
8792import org .apache .beam .sdk .values .Row ;
96101import org .apache .commons .pool2 .impl .GenericObjectPoolConfig ;
97102import org .checkerframework .checker .nullness .qual .Nullable ;
98103import org .joda .time .Duration ;
104+ import org .joda .time .Instant ;
99105import org .slf4j .Logger ;
100106import org .slf4j .LoggerFactory ;
101107
@@ -1318,6 +1324,11 @@ public static class Write<T> extends PTransform<PCollection<T>, PDone> {
13181324 this .inner = inner ;
13191325 }
13201326
1327+ /** See {@link WriteVoid#withAutoSharding()}. */
1328+ public Write <T > withAutoSharding () {
1329+ return new Write <>(inner .withAutoSharding ());
1330+ }
1331+
13211332 /** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */
13221333 public Write <T > withDataSourceConfiguration (DataSourceConfiguration config ) {
13231334 return new Write <>(inner .withDataSourceConfiguration (config ));
@@ -1393,6 +1404,7 @@ public <V extends JdbcWriteResult> WriteWithResults<T, V> withWriteResults(
13931404 .setPreparedStatementSetter (inner .getPreparedStatementSetter ())
13941405 .setStatement (inner .getStatement ())
13951406 .setTable (inner .getTable ())
1407+ .setAutoSharding (inner .getAutoSharding ())
13961408 .build ();
13971409 }
13981410
@@ -1408,6 +1420,50 @@ public PDone expand(PCollection<T> input) {
14081420 }
14091421 }
14101422
1423+ /* The maximum number of elements that will be included in a batch. */
1424+ private static final Integer MAX_BUNDLE_SIZE = 5000 ;
1425+
1426+ static <T > PCollection <Iterable <T >> batchElements (
1427+ PCollection <T > input , Boolean withAutoSharding ) {
1428+ PCollection <Iterable <T >> iterables ;
1429+ if (input .isBounded () == IsBounded .UNBOUNDED && withAutoSharding != null && withAutoSharding ) {
1430+ iterables =
1431+ input
1432+ .apply (WithKeys .<String , T >of ("" ))
1433+ .apply (
1434+ GroupIntoBatches .<String , T >ofSize (DEFAULT_BATCH_SIZE )
1435+ .withMaxBufferingDuration (Duration .millis (200 ))
1436+ .withShardedKey ())
1437+ .apply (Values .create ());
1438+ } else {
1439+ iterables =
1440+ input .apply (
1441+ ParDo .of (
1442+ new DoFn <T , Iterable <T >>() {
1443+ List <T > outputList ;
1444+
1445+ @ ProcessElement
1446+ public void process (ProcessContext c ) {
1447+ if (outputList == null ) {
1448+ outputList = new ArrayList <>();
1449+ }
1450+ outputList .add (c .element ());
1451+ if (outputList .size () > MAX_BUNDLE_SIZE ) {
1452+ c .output (outputList );
1453+ outputList = null ;
1454+ }
1455+ }
1456+
1457+ @ FinishBundle
1458+ public void finish (FinishBundleContext c ) {
1459+ c .output (outputList , Instant .now (), GlobalWindow .INSTANCE );
1460+ outputList = null ;
1461+ }
1462+ }));
1463+ }
1464+ return iterables ;
1465+ }
1466+
14111467 /** Interface implemented by functions that sets prepared statement data. */
14121468 @ FunctionalInterface
14131469 interface PreparedStatementSetCaller extends Serializable {
@@ -1430,6 +1486,8 @@ void set(
14301486 @ AutoValue
14311487 public abstract static class WriteWithResults <T , V extends JdbcWriteResult >
14321488 extends PTransform <PCollection <T >, PCollection <V >> {
1489+ abstract @ Nullable Boolean getAutoSharding ();
1490+
14331491 abstract @ Nullable SerializableFunction <Void , DataSource > getDataSourceProviderFn ();
14341492
14351493 abstract @ Nullable ValueProvider <String > getStatement ();
@@ -1451,6 +1509,8 @@ abstract static class Builder<T, V extends JdbcWriteResult> {
14511509 abstract Builder <T , V > setDataSourceProviderFn (
14521510 SerializableFunction <Void , DataSource > dataSourceProviderFn );
14531511
1512+ abstract Builder <T , V > setAutoSharding (Boolean autoSharding );
1513+
14541514 abstract Builder <T , V > setStatement (ValueProvider <String > statement );
14551515
14561516 abstract Builder <T , V > setPreparedStatementSetter (PreparedStatementSetter <T > setter );
@@ -1487,6 +1547,11 @@ public WriteWithResults<T, V> withPreparedStatementSetter(PreparedStatementSette
14871547 return toBuilder ().setPreparedStatementSetter (setter ).build ();
14881548 }
14891549
1550+ /** If true, enables using a dynamically determined number of shards to write. */
1551+ public WriteWithResults <T , V > withAutoSharding () {
1552+ return toBuilder ().setAutoSharding (true ).build ();
1553+ }
1554+
14901555 /**
14911556 * When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it
14921557 * will retry the statements. If {@link RetryStrategy#apply(SQLException)} returns {@code true},
@@ -1549,8 +1614,14 @@ public PCollection<V> expand(PCollection<T> input) {
15491614 checkArgument (
15501615 (getDataSourceProviderFn () != null ),
15511616 "withDataSourceConfiguration() or withDataSourceProviderFn() is required" );
1617+ checkArgument (
1618+ getAutoSharding () == null
1619+ || (getAutoSharding () && input .isBounded () != IsBounded .UNBOUNDED ),
1620+ "Autosharding is only supported for streaming pipelines." );
1621+ ;
15521622
1553- return input .apply (
1623+ PCollection <Iterable <T >> iterables = JdbcIO .<T >batchElements (input , getAutoSharding ());
1624+ return iterables .apply (
15541625 ParDo .of (
15551626 new WriteFn <T , V >(
15561627 WriteFnSpec .builder ()
@@ -1573,6 +1644,8 @@ public PCollection<V> expand(PCollection<T> input) {
15731644 @ AutoValue
15741645 public abstract static class WriteVoid <T > extends PTransform <PCollection <T >, PCollection <Void >> {
15751646
1647+ abstract @ Nullable Boolean getAutoSharding ();
1648+
15761649 abstract @ Nullable SerializableFunction <Void , DataSource > getDataSourceProviderFn ();
15771650
15781651 abstract @ Nullable ValueProvider <String > getStatement ();
@@ -1591,6 +1664,8 @@ public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCo
15911664
15921665 @ AutoValue .Builder
15931666 abstract static class Builder <T > {
1667+ abstract Builder <T > setAutoSharding (Boolean autoSharding );
1668+
15941669 abstract Builder <T > setDataSourceProviderFn (
15951670 SerializableFunction <Void , DataSource > dataSourceProviderFn );
15961671
@@ -1609,6 +1684,11 @@ abstract Builder<T> setDataSourceProviderFn(
16091684 abstract WriteVoid <T > build ();
16101685 }
16111686
1687+ /** If true, enables using a dynamically determined number of shards to write. */
1688+ public WriteVoid <T > withAutoSharding () {
1689+ return toBuilder ().setAutoSharding (true ).build ();
1690+ }
1691+
16121692 public WriteVoid <T > withDataSourceConfiguration (DataSourceConfiguration config ) {
16131693 return withDataSourceProviderFn (new DataSourceProviderFromDataSourceConfiguration (config ));
16141694 }
@@ -1708,7 +1788,10 @@ public PCollection<Void> expand(PCollection<T> input) {
17081788 checkArgument (
17091789 spec .getPreparedStatementSetter () != null , "withPreparedStatementSetter() is required" );
17101790 }
1711- return input
1791+
1792+ PCollection <Iterable <T >> iterables = JdbcIO .<T >batchElements (input , getAutoSharding ());
1793+
1794+ return iterables
17121795 .apply (
17131796 ParDo .of (
17141797 new WriteFn <T , Void >(
@@ -1955,7 +2038,7 @@ public void populateDisplayData(DisplayData.Builder builder) {
19552038 * @param <T>
19562039 * @param <V>
19572040 */
1958- static class WriteFn <T , V > extends DoFn <T , V > {
2041+ static class WriteFn <T , V > extends DoFn <Iterable < T > , V > {
19592042
19602043 @ AutoValue
19612044 abstract static class WriteFnSpec <T , V > implements Serializable , HasDisplayData {
@@ -2045,7 +2128,6 @@ abstract static class Builder<T, V> {
20452128 private Connection connection ;
20462129 private PreparedStatement preparedStatement ;
20472130 private static FluentBackoff retryBackOff ;
2048- private final List <T > records = new ArrayList <>();
20492131
20502132 public WriteFn (WriteFnSpec <T , V > spec ) {
20512133 this .spec = spec ;
@@ -2085,17 +2167,12 @@ private Connection getConnection() throws SQLException {
20852167
20862168 @ ProcessElement
20872169 public void processElement (ProcessContext context ) throws Exception {
2088- T record = context .element ();
2089- records .add (record );
2090- if (records .size () >= spec .getBatchSize ()) {
2091- executeBatch (context );
2092- }
2170+ executeBatch (context , context .element ());
20932171 }
20942172
20952173 @ FinishBundle
20962174 public void finishBundle () throws Exception {
20972175 // We pass a null context because we only execute a final batch for WriteVoid cases.
2098- executeBatch (null );
20992176 cleanUpStatementAndConnection ();
21002177 }
21012178
@@ -2124,11 +2201,8 @@ private void cleanUpStatementAndConnection() throws Exception {
21242201 }
21252202 }
21262203
2127- private void executeBatch (ProcessContext context )
2204+ private void executeBatch (ProcessContext context , Iterable < T > records )
21282205 throws SQLException , IOException , InterruptedException {
2129- if (records .isEmpty ()) {
2130- return ;
2131- }
21322206 Long startTimeNs = System .nanoTime ();
21332207 Sleeper sleeper = Sleeper .DEFAULT ;
21342208 BackOff backoff = retryBackOff .backoff ();
@@ -2137,16 +2211,18 @@ private void executeBatch(ProcessContext context)
21372211 getConnection ().prepareStatement (spec .getStatement ().get ())) {
21382212 try {
21392213 // add each record in the statement batch
2214+ int recordsInBatch = 0 ;
21402215 for (T record : records ) {
21412216 processRecord (record , preparedStatement , context );
2217+ recordsInBatch += 1 ;
21422218 }
21432219 if (!spec .getReturnResults ()) {
21442220 // execute the batch
21452221 preparedStatement .executeBatch ();
21462222 // commit the changes
21472223 getConnection ().commit ();
21482224 }
2149- RECORDS_PER_BATCH .update (records . size () );
2225+ RECORDS_PER_BATCH .update (recordsInBatch );
21502226 MS_PER_BATCH .update (TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - startTimeNs ));
21512227 break ;
21522228 } catch (SQLException exception ) {
@@ -2164,7 +2240,6 @@ private void executeBatch(ProcessContext context)
21642240 }
21652241 }
21662242 }
2167- records .clear ();
21682243 }
21692244
21702245 private void processRecord (T record , PreparedStatement preparedStatement , ProcessContext c ) {
0 commit comments