Skip to content

Commit a17d933

Browse files
committed
Add a new rule to collapse multiple concats in Optimizer
1 parent df3869a commit a17d933

File tree

7 files changed

+149
-18
lines changed

7 files changed

+149
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
111111
RemoveRedundantProject,
112112
SimplifyCreateStructOps,
113113
SimplifyCreateArrayOps,
114-
SimplifyCreateMapOps) ++
114+
SimplifyCreateMapOps,
115+
CollapseConcat) ++
115116
extendedOperatorOptimizationRules: _*) ::
116117
Batch("Check Cartesian Products", Once,
117118
CheckCartesianProducts(conf)) ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.immutable.HashSet
21+
import scala.collection.mutable.{ArrayBuffer, Stack}
2122

2223
import org.apache.spark.sql.catalyst.analysis._
2324
import org.apache.spark.sql.catalyst.expressions._
@@ -543,3 +544,27 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
543544
}
544545
}
545546
}
547+
548+
/**
 * Collapse nested [[Concat]] expressions.
 *
 * A [[Concat]] that has at least one [[Concat]] among its direct children is replaced by a
 * single [[Concat]] whose children are the non-[[Concat]] operands of the whole nested group,
 * kept in their original left-to-right order. For example, `concat(concat(a, b), c)` becomes
 * `concat(a, b, c)`.
 */
object CollapseConcat extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
    case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
      concat.copy(children = flattenOperands(concat))
  }

  /**
   * Collects, in left-to-right order, every operand reachable from `root` through a chain of
   * directly nested [[Concat]]s. Expressions that are not a [[Concat]] are returned as-is and
   * are not descended into.
   *
   * The walk uses an explicit stack rather than recursion, so its depth is not tied to how
   * deeply the [[Concat]]s are nested.
   */
  private def flattenOperands(root: Concat): Seq[Expression] = {
    val flattened = ArrayBuffer.empty[Expression]
    val pending = new Stack[Expression]
    pending.push(root)
    while (pending.nonEmpty) {
      pending.pop() match {
        case nested: Concat =>
          // Push right-to-left so that the leftmost child is popped (and emitted) first.
          nested.children.reverseIterator.foreach(child => pending.push(child))
        case operand =>
          flattened += operand
      }
    }
    flattened.toSeq
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseConcatSuite.scala

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.dsl.plans._
21+
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.plans.PlanTest
23+
import org.apache.spark.sql.catalyst.plans.logical._
24+
import org.apache.spark.sql.catalyst.rules._
25+
import org.apache.spark.sql.types.StringType
26+
27+
28+
/**
 * Runs the [[CollapseConcat]] rule on its own and checks that differently nested [[Concat]]
 * expressions are all rewritten to the same flat [[Concat]].
 */
class CollapseConcatSuite extends PlanTest with PredicateHelper {

  /** Rule executor containing only the rule under test. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CollapseConcatSuite", FixedPoint(50), CollapseConcat) :: Nil
  }

  /**
   * Asserts that optimizing a projection of `e1` yields the same plan as the analyzed,
   * un-optimized projection of `e2`.
   */
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    // Wraps an expression in an analyzed single-column projection over a one-row relation.
    def projectAs(expr: Expression): LogicalPlan =
      Project(Alias(expr, "out")() :: Nil, OneRowRelation).analyze

    val correctAnswer = projectAs(e2)
    val actual = Optimize.execute(projectAs(e1))
    comparePlans(actual, correctAnswer)
  }

  test("collapse nested Concat exprs") {
    def str(s: String): Literal = Literal(s, StringType)

    // Every nested input below is expected to collapse to this flat form.
    val flat = Concat(Seq(str("a"), str("b"), str("c"), str("d")))

    val nestedInputs = Seq(
      // Nested Concat in the first position.
      Concat(Seq(Concat(Seq(str("a"), str("b"))), str("c"), str("d"))),
      // Nested Concat in the middle.
      Concat(Seq(str("a"), Concat(Seq(str("b"), str("c"))), str("d"))),
      // Nested Concat in the last position.
      Concat(Seq(str("a"), str("b"), Concat(Seq(str("c"), str("d"))))),
      // Several levels of nesting.
      Concat(Seq(
        Concat(Seq(
          str("a"),
          Concat(Seq(
            str("b"),
            Concat(Seq(str("c"), str("d")))))))))
    )

    nestedInputs.foreach { nested =>
      assertEquivalent(nested, flat)
    }
  }
}

sql/core/src/test/resources/sql-tests/inputs/operators.sql

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ select 5 % 3;
3434
select pmod(-7, 3);
3535

3636
-- check operator precedence
37-
explain select 'a' || 1 + 2;
38-
explain select 1 - 2 || 'b';
39-
explain select 2 * 4 + 3 || 'b';
40-
explain select 3 + 1 || 'a' || 4 / 2;
41-
explain select 1 == 1 OR 'a' || 'b' == 'ab';
42-
explain select 'a' || 'c' == 'ac' AND 2 == 3;
37+
EXPLAIN SELECT 'a' || 1 + 2;
38+
EXPLAIN SELECT 1 - 2 || 'b';
39+
EXPLAIN SELECT 2 * 4 + 3 || 'b';
40+
EXPLAIN SELECT 3 + 1 || 'a' || 4 / 2;
41+
EXPLAIN SELECT 1 == 1 OR 'a' || 'b' == 'ab';
42+
EXPLAIN SELECT 'a' || 'c' == 'ac' AND 2 == 3;

sql/core/src/test/resources/sql-tests/inputs/string-functions.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@ select concat_ws();
33
select format_string();
44

55
-- A pipe operator for string concatenation
6-
select 'a' || 'b' || 'c';
6+
SELECT 'a' || 'b';
7+
8+
-- Check if catalyst collapses multiple `Concat`s
9+
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
10+
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10));

sql/core/src/test/resources/sql-tests/results/operators.sql.out

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ struct<pmod(-7, 3):int>
227227

228228

229229
-- !query 28
230-
explain select 'a' || 1 + 2
230+
EXPLAIN SELECT 'a' || 1 + 2
231231
-- !query 28 schema
232232
struct<plan:string>
233233
-- !query 28 output
@@ -237,7 +237,7 @@ struct<plan:string>
237237

238238

239239
-- !query 29
240-
explain select 1 - 2 || 'b'
240+
EXPLAIN SELECT 1 - 2 || 'b'
241241
-- !query 29 schema
242242
struct<plan:string>
243243
-- !query 29 output
@@ -247,7 +247,7 @@ struct<plan:string>
247247

248248

249249
-- !query 30
250-
explain select 2 * 4 + 3 || 'b'
250+
EXPLAIN SELECT 2 * 4 + 3 || 'b'
251251
-- !query 30 schema
252252
struct<plan:string>
253253
-- !query 30 output
@@ -257,7 +257,7 @@ struct<plan:string>
257257

258258

259259
-- !query 31
260-
explain select 3 + 1 || 'a' || 4 / 2
260+
EXPLAIN SELECT 3 + 1 || 'a' || 4 / 2
261261
-- !query 31 schema
262262
struct<plan:string>
263263
-- !query 31 output
@@ -267,7 +267,7 @@ struct<plan:string>
267267

268268

269269
-- !query 32
270-
explain select 1 == 1 OR 'a' || 'b' == 'ab'
270+
EXPLAIN SELECT 1 == 1 OR 'a' || 'b' == 'ab'
271271
-- !query 32 schema
272272
struct<plan:string>
273273
-- !query 32 output
@@ -277,7 +277,7 @@ struct<plan:string>
277277

278278

279279
-- !query 33
280-
explain select 'a' || 'c' == 'ac' AND 2 == 3
280+
EXPLAIN SELECT 'a' || 'c' == 'ac' AND 2 == 3
281281
-- !query 33 schema
282282
struct<plan:string>
283283
-- !query 33 output

sql/core/src/test/resources/sql-tests/results/string-functions.sql.out

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 3
2+
-- Number of queries: 4
33

44

55
-- !query 0
@@ -21,8 +21,34 @@ requirement failed: format_string() should take at least 1 argument; line 1 pos
2121

2222

2323
-- !query 2
24-
select 'a' || 'b' || 'c'
24+
SELECT 'a' || 'b'
2525
-- !query 2 schema
26-
struct<concat(concat(a, b), c):string>
26+
struct<concat(a, b):string>
2727
-- !query 2 output
28-
abc
28+
ab
29+
30+
31+
-- !query 3
32+
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
33+
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10))
34+
-- !query 3 schema
35+
struct<plan:string>
36+
-- !query 3 output
37+
== Parsed Logical Plan ==
38+
'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x]
39+
+- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x]
40+
+- 'UnresolvedTableValuedFunction range, [10]
41+
42+
== Analyzed Logical Plan ==
43+
col: string
44+
Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x]
45+
+- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL]
46+
+- Range (0, 10, step=1, splits=None)
47+
48+
== Optimized Logical Plan ==
49+
Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
50+
+- Range (0, 10, step=1, splits=None)
51+
52+
== Physical Plan ==
53+
*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
54+
+- *Range (0, 10, step=1, splits=2)

0 commit comments

Comments
 (0)