
Commit 021977f

Implemented another idea
1 parent 148b6d5 commit 021977f

9 files changed, +77 -94 lines changed


python/pyspark/sql/tests.py

Lines changed: 0 additions & 1 deletion
@@ -1137,7 +1137,6 @@ def test_access_column(self):
         self.assertTrue(isinstance(df['key'], Column))
         self.assertTrue(isinstance(df[0], Column))
         self.assertRaises(IndexError, lambda: df[2])
-        self.assertRaises(AnalysisException, lambda: df["bad_key"])
         self.assertRaises(TypeError, lambda: df[{}])

     def test_column_name_with_non_ascii(self):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 38 additions & 26 deletions
@@ -603,15 +603,21 @@ class Analyzer(
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressionsUp {
-          case u @ UnresolvedAttribute(nameParts) =>
+          case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
             // Leave unchanged if resolution fails. Hopefully will be resolved next round.
             val result =
-              withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+              withPosition(u) {
+                targetPlanIdOpt match {
+                  case Some(targetPlanId) =>
+                    resolveExpressionFromSpecificLogicalPlan(nameParts, q, targetPlanId)
+                  case None =>
+                    q.resolveChildren(nameParts, resolver).getOrElse(u)
+                }
+              }
             logDebug(s"Resolving $u to $result")
             result
           case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
             ExtractValue(child, fieldExpr, resolver)
-          case l: LazilyDeterminedAttribute => resolveLazilyDeterminedAttribute(l, q)
         }
     }

@@ -684,22 +690,18 @@ class Analyzer(
       exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
     }

-    private def resolveLazilyDeterminedAttribute(
-        expr: LazilyDeterminedAttribute,
-        plan: LogicalPlan): Expression = {
-
-      val foundPlanOpt = plan.findByBreadthFirst(_.planId == expr.plan.planId)
-      val foundPlan = foundPlanOpt.getOrElse {
-        failAnalysis(s"""Cannot resolve column name "${expr.name}" """)
-      }
-
-      if (foundPlan == expr.plan) {
-        expr.namedExpr
-      } else {
-        foundPlan.resolveQuoted(expr.name, resolver).getOrElse {
-          failAnalysis(s"""Cannot resolve column name "${expr.name}" """ +
-            s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""")
-        }
+    private[sql] def resolveExpressionFromSpecificLogicalPlan(
+        nameParts: Seq[String],
+        planToSearchFrom: LogicalPlan,
+        targetPlanId: Long): Expression = {
+      lazy val name = UnresolvedAttribute(nameParts).name
+      planToSearchFrom.findByBreadthFirst(_.planId == targetPlanId) match {
+        case Some(foundPlan) =>
+          foundPlan.resolve(nameParts, resolver).getOrElse {
+            failAnalysis(s"Could not find $name in ${planToSearchFrom.output.mkString(", ")}")
+          }
+        case None =>
+          failAnalysis(s"Could not find $name in ${planToSearchFrom.output.mkString(", ")}")
       }
     }

@@ -714,11 +716,16 @@ class Analyzer(
     try {
       expr transformUp {
         case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
-        case u @ UnresolvedAttribute(nameParts) =>
-          withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+        case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
+          withPosition(u) {
+            targetPlanIdOpt match {
+              case Some(targetPlanId) =>
+                resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
+              case None => plan.resolve(nameParts, resolver).getOrElse(u)
+            }
+          }
         case UnresolvedExtractValue(child, fieldName) if child.resolved =>
           ExtractValue(child, fieldName, resolver)
-        case l: LazilyDeterminedAttribute => resolveLazilyDeterminedAttribute(l, plan)
       }
     } catch {
       case a: AnalysisException if !throws => expr

@@ -942,12 +949,17 @@ class Analyzer(
     plan transformDown {
       case q: LogicalPlan if q.childrenResolved && !q.resolved =>
         q transformExpressions {
-          case u @ UnresolvedAttribute(nameParts) =>
+          case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
             withPosition(u) {
               try {
-                outer.resolve(nameParts, resolver) match {
-                  case Some(outerAttr) => OuterReference(outerAttr)
-                  case None => u
+                targetPlanIdOpt match {
+                  case Some(targetPlanId) =>
+                    resolveExpressionFromSpecificLogicalPlan(nameParts, outer, targetPlanId)
+                  case None =>
+                    outer.resolve(nameParts, resolver) match {
+                      case Some(outerAttr) => OuterReference(outerAttr)
+                      case None => u
+                    }
                 }
               } catch {
                 case _: AnalysisException => u
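
Note: the three hunks above share one dispatch pattern. As a minimal sketch under this commit's names (`targetPlanIdOpt` on `UnresolvedAttribute`, plus the new `resolveExpressionFromSpecificLogicalPlan`), the added branches boil down to:

// Sketch, not the exact rule body: an attribute that carries a target plan id
// is resolved against the plan found by that id; otherwise resolution falls
// back to the usual name-based lookup, leaving the attribute unresolved on a miss.
def resolveAttr(u: UnresolvedAttribute, plan: LogicalPlan): Expression =
  u.targetPlanIdOpt match {
    case Some(id) => resolveExpressionFromSpecificLogicalPlan(u.nameParts, plan, id)
    case None     => plan.resolve(u.nameParts, resolver).getOrElse(u)
  }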

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 3 additions & 40 deletions
@@ -83,7 +83,9 @@ case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq
 /**
  * Holds the name of an attribute that has yet to be resolved.
  */
-case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
+case class UnresolvedAttribute(
+    nameParts: Seq[String],
+    targetPlanIdOpt: Option[Long] = None) extends Attribute with Unevaluable {

   def name: String =
     nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
@@ -419,42 +421,3 @@ case class UnresolvedOrdinal(ordinal: Int)
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
-
-/**
- * This is used when we refer a column like `df("expr")`
- * and determines which expression `df("expr")` should point to lazily.
- * Normally, `df("expr")` should point the expression (say expr1 here.) which
- * the logical plan in `df` outputs. but we have some cases that `df("expr")` should
- * point to another expression (say expr2 here) rather than expr1
- * and in this case, expr2 is equally to expr1 except exprId.
- * This will happen when datasets are self-joined or in similar situations and in this situation,
- * logical plans and expressions of those outputs are re-created with new exprIds the analyzer.
- * [[LazilyDeterminedAttribute()]] can treat this case properly
- * to determine that `df("expr")` should point which expression in the analyzer.
- *
- * @param namedExpr The expression which a column reference should point to normally.
- * @param plan The logical plan which contains the expression
- *             which the column reference should point to lazily.
- */
-case class LazilyDeterminedAttribute(
-    namedExpr: NamedExpression)(
-    val plan: LogicalPlan)
-  extends Attribute with Unevaluable {
-  // We need to keep the constructor curried
-  // so that we can compare like df1("col1") == df2("col1") especially in case of test.
-
-  override def name: String = namedExpr.name
-  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
-  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
-  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
-  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
-  override lazy val resolved = false
-
-  override def newInstance(): Attribute = throw new UnresolvedException(this, "newInstance")
-  override def withNullability(newNullability: Boolean): Attribute =
-    throw new UnresolvedException(this, "withNullability")
-  override def withName(newName: String): Attribute =
-    throw new UnresolvedException(this, "withName")
-  override def withQualifier(newQualifier: Option[String]): Attribute =
-    throw new UnresolvedException(this, "withQualifier")
-}
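
Note: `targetPlanIdOpt` replaces the curried `LazilyDeterminedAttribute` deleted above, so a column reference now carries a plan id rather than a captured plan. A hedged sketch of constructing such a reference (the `42L` id is illustrative; `parseAttributeName` is the existing helper this patch uses in `Dataset.col`):

// Illustrative: a reference to "a.b" that should resolve against plan 42.
val ref = UnresolvedAttribute(UnresolvedAttribute.parseAttributeName("a.b"), Some(42L))
assert(ref.name == "a.b") // the name is rebuilt from the parsed parts
assert(!ref.resolved)     // stays unresolved until the analyzer finds plan 42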

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ object ExpressionEncoder {
     } else {
       val input = GetColumnByOrdinal(index, enc.schema)
       val deserialized = enc.deserializer.transformUp {
-        case UnresolvedAttribute(nameParts) =>
+        case UnresolvedAttribute(nameParts, _) =>
           assert(nameParts.length == 1)
           UnresolvedExtractValue(input, Literal(nameParts.head))
         case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 2 additions & 2 deletions
@@ -1168,8 +1168,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
     val attr = ctx.fieldName.getText
     expression(ctx.base) match {
-      case UnresolvedAttribute(nameParts) =>
-        UnresolvedAttribute(nameParts :+ attr)
+      case UnresolvedAttribute(nameParts, targetPlanId) =>
+        UnresolvedAttribute(nameParts :+ attr, targetPlanId)
       case e =>
         UnresolvedExtractValue(e, Literal(attr))
     }

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 8 additions & 8 deletions
@@ -161,7 +161,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
         fillCol[Double](f, value)
       } else {
-        df.colInternal(f.name)
+        df.col(f.name)
       }
     }
     df.select(projections : _*)
@@ -188,7 +188,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
         fillCol[String](f, value)
       } else {
-        df.colInternal(f.name)
+        df.col(f.name)
       }
     }
     df.select(projections : _*)
@@ -363,7 +363,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       } else if (f.dataType == targetColumnType && shouldReplace) {
         replaceCol(f, replacementMap)
       } else {
-        df.colInternal(f.name)
+        df.col(f.name)
       }
     }
     df.select(projections : _*)
@@ -395,7 +395,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
           case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
           case v: String => fillCol[String](f, v)
         }
-      }.getOrElse(df.colInternal(f.name))
+      }.getOrElse(df.col(f.name))
     }
     df.select(projections : _*)
   }
@@ -407,8 +407,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     val quotedColName = "`" + col.name + "`"
     val colValue = col.dataType match {
       case DoubleType | FloatType =>
-        nanvl(df.colInternal(quotedColName), lit(null)) // nanvl only supports these types
-      case _ => df.colInternal(quotedColName)
+        nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
+      case _ => df.col(quotedColName)
     }
     coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name)
   }
@@ -420,8 +420,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
   private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
-    val keyExpr = df.colInternal(col.name).expr
-    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+    val keyExpr = df.col(col.name).expr
+    def buildExpr(v: Any) = Cast(Literal(v), col.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(buildExpr(source), buildExpr(target))
     }.toSeq
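
Note: these call sites move from the eager, `private[sql]` `colInternal` to the public `col`, which now defers resolution. For a single DataFrame the two should pick out the same attribute, since the tagged plan id points back at `df`'s own plan; a sketch of the assumed equivalence (column name illustrative, `colInternal` shown only for contrast):

// Assumed equivalence for a single DataFrame:
val eager  = df.colInternal("age") // resolved immediately against df's plan
val tagged = df.col("age")         // carries df's plan id; resolved by the analyzer
assert(df.select(eager).collect().sameElements(df.select(tagged).collect()))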

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 19 additions & 14 deletions
@@ -899,7 +899,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def sort(sortCol: String, sortCols: String*): Dataset[T] = {
-    sort((sortCol +: sortCols).map(colInternal) : _*)
+    sort((sortCol +: sortCols).map(apply) : _*)
   }

   /**
@@ -953,8 +953,9 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def col(colName: String): Column = withStarResolved(colName) {
-    val candidateExpr = resolve(colName)
-    val expr = LazilyDeterminedAttribute(candidateExpr)(logicalPlan)
+    val expr = UnresolvedAttribute(
+      UnresolvedAttribute.parseAttributeName(colName),
+      Some(queryExecution.analyzed.planId))
     Column(expr)
   }

@@ -1703,8 +1704,7 @@ class Dataset[T] private[sql](
       val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
       f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o)))
     }
-    val generator =
-      UserDefinedGenerator(elementSchema, rowFunction, colInternal(inputColumn).expr :: Nil)
+    val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil)

     withPlan {
       Generate(generator, join = true, outer = false,
@@ -1832,15 +1832,17 @@ class Dataset[T] private[sql](
    */
   def drop(col: Column): DataFrame = {
     val expression = col match {
-      case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(
-          u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
-      case Column(l: LazilyDeterminedAttribute) =>
-        val foundExpression =
-          logicalPlan.findByBreadthFirst(_.planId == l.plan.planId)
-            .flatMap(_.resolveQuoted(l.name, sparkSession.sessionState.analyzer.resolver))
-            .getOrElse(l.namedExpr)
-        foundExpression
+      case Column(u @ UnresolvedAttribute(nameParts, targetPlanIdOpt)) =>
+        val plan = queryExecution.analyzed
+        val analyzer = sparkSession.sessionState.analyzer
+        val resolver = analyzer.resolver
+
+        targetPlanIdOpt match {
+          case Some(targetPlanId) =>
+            analyzer.resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
+          case None =>
+            plan.resolveQuoted(u.name, resolver).getOrElse(u)
+        }
       case Column(expr: Expression) => expr
     }
     val attrs = this.logicalPlan.output
@@ -2633,6 +2635,9 @@ class Dataset[T] private[sql](
     }
   }

+  /** Another version of `col` which resolves an expression immediately.
+   * Mainly intended for tests, e.g. when passing columns to a SparkPlan.
+   */
   private[sql] def colInternal(colName: String): Column = withStarResolved(colName) {
     val expr = resolve(colName)
     Column(expr)
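
Note: the practical effect of the new `col` is that `df("name")` no longer resolves eagerly; it returns an `UnresolvedAttribute` tagged with the plan id of `df`'s analyzed plan, so the analyzer can pick the intended operand even after plans are re-created, as in a self-join. A hedged usage sketch (names are illustrative):

// Illustrative: both sides reference "id", but each Column carries the plan id
// of the Dataset it came from, so the join condition stays unambiguous.
val df1 = spark.range(10).toDF("id")
val df2 = df1.filter(df1("id") > 3)
df1.join(df2, df1("id") === df2("id")).show()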

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 5 additions & 1 deletion
@@ -607,7 +607,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(id, name, age, salary)
     }.toSeq)
     assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
-    assert(df("id") == person("id"))
+    val dfAnalyzer = df.sparkSession.sessionState.analyzer
+    val personAnalyzer = person.sparkSession.sessionState.analyzer
+    assert(dfAnalyzer.resolveExpression(df("id").expr, df.queryExecution.analyzed) ==
+      personAnalyzer.resolveExpression(person("id").expr, person.queryExecution.analyzed))
   }

   test("drop top level columns that contains dot") {
@@ -1469,6 +1472,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4)
     }
   }
+
   test("sameResult() on aggregate") {
     val df = spark.range(100)
     val agg1 = df.groupBy().count()

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ object SparkPlanTest {
     case plan: SparkPlan =>
       val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
       plan transformExpressions {
-        case UnresolvedAttribute(Seq(u)) =>
+        case UnresolvedAttribute(Seq(u), _) =>
           inputMap.getOrElse(u,
             sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
       }
