
Commit 3f1e6a1 (parent: 86d251c)

Add new query hint NO_COLLAPSE.

6 files changed: 61 additions, 15 deletions

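Context for the change: Spark's CollapseProject optimizer rule merges adjacent Project nodes by substituting the inner projection's expressions into the outer one. When the outer projection references an expensive inner expression several times, collapsing duplicates that work; the new NO_COLLAPSE hint keeps the projections separate. A minimal sketch of the intent, using the Scala API added below (`spark`, `df`, and `expensive` are illustrative names, not part of this commit):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{no_collapse, udf}

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    val df = Seq("spark", "hint").toDF("a")

    // Stand-in for a genuinely expensive function.
    val expensive = udf((s: String) => s.length)

    // Without the hint, CollapseProject merges the two selects into a single
    // projection that evaluates `expensive` twice per row.
    val collapsed = df.select(expensive($"a").as("x")).select($"x" + 1, $"x" + 2)

    // With the hint, the inner projection survives and `expensive` runs once.
    val hinted = no_collapse(df.select(expensive($"a").as("x"))).select($"x" + 1, $"x" + 2)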

python/pyspark/sql/functions.py (8 additions, 0 deletions)

@@ -466,6 +466,14 @@ def nanvl(col1, col2):
     return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2)))


+@since(2.2)
+def no_collapse(df):
+    """Marks a DataFrame so that the optimizer will not collapse its projections."""
+
+    sc = SparkContext._active_spark_context
+    return DataFrame(sc._jvm.functions.no_collapse(df._jdf), df.sql_ctx)
+
+
 @since(1.4)
 def rand(seed=None):
     """Generates a random column with independent and identically distributed (i.i.d.) samples

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala (7 additions, 0 deletions)

@@ -386,6 +386,13 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
     child.stats(conf).copy(isBroadcastable = true)
 }

+/**
+ * A hint for the optimizer that it should not merge two adjacent projections.
+ */
+case class NoCollapseHint(child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+}
+
 /**
  * A general hint for the child. This node will be eliminated post analysis.
  * A pair of (name, parameters).
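Note that the commit adds only the hint node; the CollapseProject rule itself is untouched. That works because the rule fires only on directly nested projections, so a hint node sitting between two Projects breaks the pattern. Roughly (a simplified sketch of the Spark 2.x rule, not code from this diff):

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Simplified: the real rule also merges the two project lists and carries
    // extra guards. The point is the Project-over-Project pattern; with a
    // NoCollapseHint node in between, this case no longer matches.
    object CollapseProjectSketch extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
        case p @ Project(_, _: Project) => p // the real rule returns the merged Project
      }
    }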

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala (11 additions, 1 deletion)

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Rand
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoCollapseHint}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor

 class CollapseProjectSuite extends PlanTest {
@@ -119,4 +119,14 @@ class CollapseProjectSuite extends PlanTest {

     comparePlans(optimized, correctAnswer)
   }
+
+  test("do not collapse projects wrapped in NoCollapseHint") {
+    val query = NoCollapseHint(testRelation.select(('a * 10).as('a_times_10)))
+      .select(('a_times_10 + 1).as('a_times_10_plus_1), ('a_times_10 + 2).as('a_times_10_plus_2))
+
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = query.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala (5 additions, 0 deletions)

@@ -537,5 +537,10 @@ class PlanParserSuite extends PlanTest {
     comparePlans(
       parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
       Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))
+
+    comparePlans(
+      parsePlan("SELECT a FROM (SELECT /*+ NO_COLLAPSE */ * FROM t) t1"),
+      SubqueryAlias("t1", Hint("NO_COLLAPSE", Seq.empty, table("t").select(star())))
+        .select('a))
   }
 }
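Worth noting: this test only checks that NO_COLLAPSE parses into the generic Hint node, which (per its scaladoc above) is eliminated post analysis. An analyzer rule turning Hint("NO_COLLAPSE", ...) into NoCollapseHint does not appear in this diff. A hypothetical sketch of what such a resolution rule could look like (not part of this commit):

    import org.apache.spark.sql.catalyst.plans.logical.{Hint, LogicalPlan, NoCollapseHint}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical analyzer rule: rewrite the parsed hint into the node the
    // optimizer and planner actually understand.
    object ResolveNoCollapseHint extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
        case Hint("NO_COLLAPSE", _, child) => NoCollapseHint(child)
      }
    }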

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala (1 addition, 0 deletions)

@@ -433,6 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
+      case NoCollapseHint(child) => planLater(child) :: Nil
       case _ => Nil
     }
   }
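Like BroadcastHint on the line above it, NoCollapseHint is erased at planning time: the strategy simply plans the hint's child, so the node never reaches the physical plan and exists only to influence the optimizer.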

sql/core/src/main/scala/org/apache/spark/sql/functions.scala (29 additions, 14 deletions)

@@ -19,17 +19,16 @@ package org.apache.spark.sql

 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
 import scala.util.Try
 import scala.util.control.NonFatal
-
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, NoCollapseHint}
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.internal.SQLConf
@@ -1007,21 +1006,37 @@ object functions {
   def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }

   /**
-   * Marks a DataFrame as small enough for use in broadcast joins.
-   *
-   * The following example marks the right DataFrame for broadcast hash join using `joinKey`.
-   * {{{
-   *   // left and right are DataFrames
-   *   left.join(broadcast(right), "joinKey")
-   * }}}
-   *
-   * @group normal_funcs
-   * @since 1.5.0
-   */
+   * Marks a DataFrame as small enough for use in broadcast joins.
+   *
+   * The following example marks the right DataFrame for broadcast hash join using `joinKey`.
+   * {{{
+   *   // left and right are DataFrames
+   *   left.join(broadcast(right), "joinKey")
+   * }}}
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
   def broadcast[T](df: Dataset[T]): Dataset[T] = {
     Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
   }

+  /**
+   * Marks a DataFrame so that the optimizer will not collapse its projections
+   * with adjacent ones.
+   *
+   * {{{
+   *   // the inner select is kept as a separate projection
+   *   no_collapse(df.select(($"a" * 10).as("x"))).select($"x" + 1, $"x" + 2)
+   * }}}
+   *
+   * @group normal_funcs
+   * @since 2.2.0
+   */
+  def no_collapse[T](df: Dataset[T]): Dataset[T] = {
+    Dataset[T](df.sparkSession, NoCollapseHint(df.logicalPlan))(df.exprEnc)
+  }
+
   /**
    * Returns the first column that is not null, or null if all inputs are null.
    *