@@ -168,7 +168,7 @@ case class ExpandExec(
168168 }
169169
170170 // Part 2: switch/case statements
171- val cases = projections.zipWithIndex.map { case (exprs, row) =>
171+ val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
172172 val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
173173 if (! sameOutput(col)) {
174174 val boundExpr = BindReferences .bindReference(exprs(col), attributeSeq)
@@ -180,40 +180,55 @@ case class ExpandExec(
180180 }
181181 }.unzip
182182
183- val updateCode = exprCodesWithIndices.map { case (col, ev) =>
183+ val inputVars = inputVarSets.foldLeft(Set .empty[VariableValue ])(_ ++ _)
184+ (row, exprCodesWithIndices, inputVars.toSeq)
185+ }
186+
187+ def generateUpdateCode (exprCodes : Seq [(Int , ExprCode )]): String = {
188+ exprCodes.map { case (col, ev) =>
184189 s """
185190 | ${ev.code}
186191 | ${outputColumns(col).isNull} = ${ev.isNull};
187192 | ${outputColumns(col).value} = ${ev.value};
188193 """ .stripMargin
189- }
194+ }.mkString(" \n " )
195+ }
190196
191- val splitThreshold = SQLConf .get.methodSplitThreshold
192- val inputVars = inputVarSets.foldLeft(Set .empty[VariableValue ])(_ ++ _)
193- val paramLength = CodeGenerator .calculateParamLengthFromExprValues(inputVars.toSeq)
194- val maybeSplitUpdateCode = if (CodeGenerator .isValidParamLength(paramLength) &&
195- exprCodesWithIndices.map(_._2.code.length).sum > splitThreshold) {
196- val switchCaseFunc = ctx.freshName(" switchCaseCode" )
197- val argList = inputVars.map { v =>
198- s " ${CodeGenerator .typeName(v.javaType)} ${v.variableName}"
197+ val splitThreshold = SQLConf .get.methodSplitThreshold
198+ val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) {
199+ switchCaseExprs.map { case (row, exprCodes, inputVars) =>
200+ val updateCode = generateUpdateCode(exprCodes)
201+ val paramLength = CodeGenerator .calculateParamLengthFromExprValues(inputVars)
202+ val maybeSplitUpdateCode = if (CodeGenerator .isValidParamLength(paramLength)) {
203+ val switchCaseFunc = ctx.freshName(" switchCaseCode" )
204+ val argList = inputVars.map { v =>
205+ s " ${CodeGenerator .typeName(v.javaType)} ${v.variableName}"
206+ }
207+ ctx.addNewFunction(switchCaseFunc,
208+ s """
209+ |private void $switchCaseFunc( ${argList.mkString(" , " )}) {
210+ | $updateCode
211+ |}
212+ """ .stripMargin)
213+
214+ s " $switchCaseFunc( ${inputVars.map(_.variableName).mkString(" , " )}); "
215+ } else {
216+ updateCode
199217 }
200- ctx.addNewFunction(switchCaseFunc,
201- s """
202- |private void $switchCaseFunc( ${argList.mkString(" , " )}) {
203- | ${updateCode.mkString(" \n " )}
204- |}
205- """ .stripMargin)
206-
207- s " $switchCaseFunc( ${inputVars.map(_.variableName).mkString(" , " )}); "
208- } else {
209- updateCode.mkString(" \n " )
218+ s """
219+ |case $row:
220+ | $maybeSplitUpdateCode
221+ | break;
222+ """ .stripMargin
223+ }
224+ } else {
225+ switchCaseExprs.map { case (row, exprCodes, _) =>
226+ s """
227+ |case $row:
228+ | ${generateUpdateCode(exprCodes)}
229+ | break;
230+ """ .stripMargin
210231 }
211-
212- s """
213- |case $row:
214- | $maybeSplitUpdateCode
215- | break;
216- """ .stripMargin
217232 }
218233
219234 val numOutput = metricTerm(ctx, " numOutputRows" )
0 commit comments