Skip to content

Commit f20f196

Browse files
committed
Fix Join output nullabilities.
1 parent 6803642 commit f20f196

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21-
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
21+
import org.apache.spark.sql.catalyst.plans._
2222
import org.apache.spark.sql.catalyst.types._
2323

2424
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -81,11 +81,28 @@ case class Join(
8181
condition: Option[Expression]) extends BinaryNode {
8282

8383
// References of the join condition when one is present; an empty set when `condition` is None.
override def references = condition.map(_.references).getOrElse(Set.empty)
84-
override def output = joinType match {
85-
case LeftSemi =>
86-
left.output
87-
case _ =>
88-
left.output ++ right.output
84+
/**
 * Output attributes of this join, selected by `joinType`.
 *
 * Attributes coming from a side that is passed through `nullabilize` are reported as nullable;
 * presumably this is because an outer join pads unmatched rows with nulls on that side
 * (NOTE(review): confirm against the physical join operators).
 */
override def output = {
85+
// Returns `output` with every non-nullable attribute replaced by an AttributeReference that
// keeps the same name, data type, exprId and qualifiers but has nullable = true.
// Attributes that are already nullable are returned unchanged.
def nullabilize(output: Seq[Attribute]) = {
86+
output.map {
87+
case attr if !attr.nullable =>
88+
AttributeReference(
89+
attr.name, attr.dataType, nullable = true)(attr.exprId, attr.qualifiers)
90+
case attr => attr
91+
}
92+
}
93+
94+
joinType match {
95+
// Left semi join: only the left side's attributes are output.
case LeftSemi =>
96+
left.output
97+
// Left outer join: right-side attributes are forced nullable.
case LeftOuter =>
98+
left.output ++ nullabilize(right.output)
99+
// Right outer join: left-side attributes are forced nullable.
case RightOuter =>
100+
nullabilize(left.output) ++ right.output
101+
// Full outer join: attributes of both sides are forced nullable.
case FullOuter =>
102+
nullabilize(left.output) ++ nullabilize(right.output)
103+
// Any other join type: both sides' attributes are concatenated with nullability unchanged.
case _ =>
104+
left.output ++ right.output
105+
}
89106
}
90107
}
91108

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -319,7 +319,26 @@ case class BroadcastNestedLoopJoin(
319319

320320
// Single-element list holding `sqlContext`.
// NOTE(review): presumably the extra constructor argument needed when this node is copied — confirm.
override def otherCopyArgs = sqlContext :: Nil
321321

322-
def output = left.output ++ right.output
322+
/**
 * Output attributes of this join, selected by `joinType`.
 *
 * Attributes coming from a side that is passed through `nullabilize` are reported as nullable;
 * presumably this is because an outer join pads unmatched rows with nulls on that side
 * (NOTE(review): confirm against the execution logic of this operator).
 *
 * NOTE(review): there is no explicit LeftSemi case, so LeftSemi would fall through to the
 * default and output both sides — confirm whether this operator is ever planned for LeftSemi.
 */
def output = {
323+
// Returns `output` with every non-nullable attribute replaced by an AttributeReference that
// keeps the same name, data type, exprId and qualifiers but has nullable = true.
// Attributes that are already nullable are returned unchanged.
def nullabilize(output: Seq[Attribute]) = {
324+
output.map {
325+
case attr if !attr.nullable =>
326+
AttributeReference(attr.name, attr.dataType, nullable = true)(attr.exprId, attr.qualifiers)
327+
case attr => attr
328+
}
329+
}
330+
331+
joinType match {
332+
// Left outer join: right-side attributes are forced nullable.
case LeftOuter =>
333+
left.output ++ nullabilize(right.output)
334+
// Right outer join: left-side attributes are forced nullable.
case RightOuter =>
335+
nullabilize(left.output) ++ right.output
336+
// Full outer join: attributes of both sides are forced nullable.
case FullOuter =>
337+
nullabilize(left.output) ++ nullabilize(right.output)
338+
// Any other join type: both sides' attributes are concatenated with nullability unchanged.
case _ =>
339+
left.output ++ right.output
340+
}
341+
}
323342

324343
/** The streamed relation, exposed as `left`. */
325344
def left = streamed

0 commit comments

Comments
 (0)