Skip to content

Commit b01bb8c

Browse files
committed
review
1 parent a1873d9 commit b01bb8c

File tree

1 file changed

+42
-27
lines changed

1 file changed

+42
-27
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ case class ExpandExec(
168168
}
169169

170170
// Part 2: switch/case statements
171-
val cases = projections.zipWithIndex.map { case (exprs, row) =>
171+
val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
172172
val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
173173
if (!sameOutput(col)) {
174174
val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq)
@@ -180,40 +180,55 @@ case class ExpandExec(
180180
}
181181
}.unzip
182182

183-
val updateCode = exprCodesWithIndices.map { case (col, ev) =>
183+
val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _)
184+
(row, exprCodesWithIndices, inputVars.toSeq)
185+
}
186+
187+
def generateUpdateCode(exprCodes: Seq[(Int, ExprCode)]): String = {
188+
exprCodes.map { case (col, ev) =>
184189
s"""
185190
|${ev.code}
186191
|${outputColumns(col).isNull} = ${ev.isNull};
187192
|${outputColumns(col).value} = ${ev.value};
188193
""".stripMargin
189-
}
194+
}.mkString("\n")
195+
}
190196

191-
val splitThreshold = SQLConf.get.methodSplitThreshold
192-
val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _)
193-
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars.toSeq)
194-
val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength) &&
195-
exprCodesWithIndices.map(_._2.code.length).sum > splitThreshold) {
196-
val switchCaseFunc = ctx.freshName("switchCaseCode")
197-
val argList = inputVars.map { v =>
198-
s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
197+
val splitThreshold = SQLConf.get.methodSplitThreshold
198+
val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) {
199+
switchCaseExprs.map { case (row, exprCodes, inputVars) =>
200+
val updateCode = generateUpdateCode(exprCodes)
201+
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars)
202+
val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) {
203+
val switchCaseFunc = ctx.freshName("switchCaseCode")
204+
val argList = inputVars.map { v =>
205+
s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
206+
}
207+
ctx.addNewFunction(switchCaseFunc,
208+
s"""
209+
|private void $switchCaseFunc(${argList.mkString(", ")}) {
210+
| $updateCode
211+
|}
212+
""".stripMargin)
213+
214+
s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});"
215+
} else {
216+
updateCode
199217
}
200-
ctx.addNewFunction(switchCaseFunc,
201-
s"""
202-
|private void $switchCaseFunc(${argList.mkString(", ")}) {
203-
| ${updateCode.mkString("\n")}
204-
|}
205-
""".stripMargin)
206-
207-
s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});"
208-
} else {
209-
updateCode.mkString("\n")
218+
s"""
219+
|case $row:
220+
| $maybeSplitUpdateCode
221+
| break;
222+
""".stripMargin
223+
}
224+
} else {
225+
switchCaseExprs.map { case (row, exprCodes, _) =>
226+
s"""
227+
|case $row:
228+
| ${generateUpdateCode(exprCodes)}
229+
| break;
230+
""".stripMargin
210231
}
211-
212-
s"""
213-
|case $row:
214-
| $maybeSplitUpdateCode
215-
| break;
216-
""".stripMargin
217232
}
218233

219234
val numOutput = metricTerm(ctx, "numOutputRows")

0 commit comments

Comments
 (0)