Skip to content

Commit 2ca9741

Browse files
committed
Address comments.
1 parent c378ce2 commit 2ca9741

File tree

6 files changed

+63
-24
lines changed

6 files changed

+63
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ abstract class Expression extends TreeNode[Expression] {
101101
ctx.subExprEliminationExprs.get(this).map { subExprState =>
102102
// This expression is repeated which means that the code to evaluate it has already been added
103103
// as a function before. In that case, we just re-use it.
104-
ExprCode(ctx.registerComment(this.toString), subExprState.isNull,
105-
subExprState.value)
104+
ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value)
106105
}.getOrElse {
107106
val isNull = ctx.freshName("isNull")
108107
val value = ctx.freshName("value")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
1919

2020
import java.lang.{Boolean => JBool}
2121

22+
import scala.collection.mutable.ArrayBuffer
2223
import scala.language.{existentials, implicitConversions}
2324

2425
import org.apache.spark.sql.types.{BooleanType, DataType}
@@ -130,6 +131,8 @@ trait Block extends JavaCode {
130131

131132
def length: Int = toString.length
132133

134+
def nonEmpty: Boolean = toString.nonEmpty
135+
133136
// The leading prefix that should be stripped from each line.
134137
// By default we strip blanks or control characters followed by '|' from the line.
135138
var _marginChar: Option[Char] = Some('|')
@@ -167,9 +170,40 @@ object Block {
167170
case other => throw new IllegalArgumentException(
168171
s"Can not interpolate ${other.getClass.getName} into code block.")
169172
}
170-
CodeBlock(sc.parts, args)
173+
174+
val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
175+
CodeBlock(codeParts, blockInputs)
176+
}
177+
}
178+
}
179+
180+
// Eagerly folds literal args into the surrounding code parts, so that only `ExprValue` and
// `Block` args are kept as block inputs.
//
// `parts` and `args` are walked in lockstep: one part, then one arg, then the next part, ...
// An arg that is an `ExprValue` or a `Block` flushes the text buffered so far into `codeParts`
// and is recorded in `blockInputs`. Any other arg is appended to the buffer as text, which
// merges it with the parts on either side of it.
//
// Returns the folded `(codeParts, blockInputs)` pair. The trailing buffer is only emitted when
// it is non-empty, so the result has no trailing empty code part.
//
// NOTE(review): `strings.next` and `inputs.next` are called without a `hasNext` check, so this
// assumes `parts` is non-empty and `args` has at least `parts.length - 1` elements (presumably
// guaranteed by `StringContext` -- confirm against the caller).
181+
private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[Any]) = {
182+
val codeParts = ArrayBuffer.empty[String]
183+
val blockInputs = ArrayBuffer.empty[Any]
184+
185+
val strings = parts.iterator
186+
val inputs = args.iterator
187+
val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
188+
189+
buf append strings.next
190+
while (strings.hasNext) {
191+
val input = inputs.next
192+
input match {
193+
case _: ExprValue | _: Block =>
194+
codeParts += buf.toString
195+
buf.clear
196+
blockInputs += input
197+
case _ =>
198+
buf append input
171199
}
200+
buf append strings.next
201+
}
202+
if (buf.nonEmpty) {
203+
codeParts += buf.toString
172204
}
205+
206+
(codeParts.toSeq, blockInputs.toSeq)
173207
}
174208
}
175209

@@ -182,11 +216,10 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc
182216
blockInputs.flatMap {
183217
case b: Block => b.exprValues
184218
case e: ExprValue => Set(e)
185-
case _ => Set.empty[ExprValue]
186219
}.toSet
187220
}
188221

189-
override def code: String = {
222+
override lazy val code: String = {
190223
val strings = codeParts.iterator
191224
val inputs = blockInputs.iterator
192225
val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
@@ -207,7 +240,7 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc
207240

208241
case class Blocks(blocks: Seq[Block]) extends Block {
209242
override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
210-
override def code: String = blocks.map(_.toString).mkString("\n")
243+
override lazy val code: String = blocks.map(_.toString).mkString("\n")
211244

212245
override def + (other: Block): Block = other match {
213246
case c: CodeBlock => Blocks(blocks :+ c)
@@ -217,7 +250,7 @@ case class Blocks(blocks: Seq[Block]) extends Block {
217250
}
218251

219252
object EmptyBlock extends Block with Serializable {
220-
override def code: String = ""
253+
override val code: String = ""
221254
override val exprValues: Set[ExprValue] = Set.empty
222255

223256
override def + (other: Block): Block = other

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
4343

4444
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
4545
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
46-
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
47-
code"$className.getInputFilePath();", isNull = FalseLiteral)
46+
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
47+
ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
48+
isNull = FalseLiteral)
4849
}
4950
}
5051

@@ -66,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
6667

6768
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
6869
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
69-
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
70-
code"$className.getStartOffset();", isNull = FalseLiteral)
70+
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
71+
ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
7172
}
7273
}
7374

@@ -89,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
8990

9091
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
9192
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
92-
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
93-
code"$className.getLength();", isNull = FalseLiteral)
93+
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
94+
ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
9495
}
9596
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,20 @@ import org.apache.spark.sql.types.{BooleanType, IntegerType}
2323

2424
class CodeBlockSuite extends SparkFunSuite {
2525

26-
test("Block can interpolate string and ExprValue inputs") {
26+
test("Block interpolates string and ExprValue inputs") {
2727
val isNull = JavaCode.isNullVariable("expr1_isNull")
28-
val code = code"boolean ${isNull} = ${JavaCode.defaultLiteral(BooleanType)};"
28+
val stringLiteral = "false"
29+
val code = code"boolean $isNull = $stringLiteral;"
2930
assert(code.toString == "boolean expr1_isNull = false;")
3031
}
3132

33+
test("Literals are folded into string code parts instead of block inputs") {
34+
val value = JavaCode.variable("expr1", IntegerType)
35+
val intLiteral = 1
36+
val code = code"int $value = $intLiteral;"
37+
assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
38+
}
39+
3240
test("Block.stripMargin") {
3341
val isNull = JavaCode.isNullVariable("expr1_isNull")
3442
val value = JavaCode.variable("expr1", IntegerType)
@@ -92,26 +100,26 @@ class CodeBlockSuite extends SparkFunSuite {
92100
}
93101

94102
test("Throws exception when interpolating unexcepted object in code block") {
95-
val obj = TestClass(100)
103+
val obj = Tuple2(1, 1)
96104
val e = intercept[IllegalArgumentException] {
97105
code"$obj"
98106
}
99107
assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
100108
}
101109

102110
test("replace expr values in code block") {
103-
val statement = JavaCode.expression("1 + 1", IntegerType)
111+
val expr = JavaCode.expression("1 + 1", IntegerType)
104112
val isNull = JavaCode.isNullVariable("expr1_isNull")
105113
val exprInFunc = JavaCode.variable("expr1", IntegerType)
106114

107115
val code =
108116
code"""
109-
|callFunc(int $statement) {
117+
|callFunc(int $expr) {
110118
| boolean $isNull = false;
111-
| int $exprInFunc = $statement + 1;
119+
| int $exprInFunc = $expr + 1;
112120
|}""".stripMargin
113121

114-
val aliasedParam = JavaCode.variable("aliased", statement.javaType)
122+
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
115123
val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
116124
case _: SimpleExprValue => aliasedParam
117125
case other => other
@@ -126,5 +134,3 @@ class CodeBlockSuite extends SparkFunSuite {
126134
assert(aliasedCode.toString == expected.toString)
127135
}
128136
}
129-
130-
private case class TestClass(a: Int)

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
5959
}
6060
val valueVar = ctx.freshName("value")
6161
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
62-
val code = code"${ctx.registerComment(str)}\n" + (if (nullable) {
62+
val code = code"${ctx.registerComment(str)}" + (if (nullable) {
6363
code"""
6464
boolean $isNullVar = $columnVar.isNullAt($ordinal);
6565
$javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ trait CodegenSupport extends SparkPlan {
260260
* them to be evaluated twice.
261261
*/
262262
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
263-
val evaluate = variables.filter(_.code.toString != "").map(_.code.toString).mkString("\n")
263+
val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n")
264264
variables.foreach(_.code = EmptyBlock)
265265
evaluate
266266
}

0 commit comments

Comments
 (0)