Skip to content

Commit 58e1cf1

Browse files
committed
[SPARK-29359] Better exception handling in SQLQueryTestSuite and ThriftServerQueryTestSuite
1 parent 130e9ae commit 58e1cf1

File tree

2 files changed

+59
-56
lines changed

2 files changed

+59
-56
lines changed

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
135135
private val notIncludedMsg = "[not included in comparison]"
136136
private val clsName = this.getClass.getCanonicalName
137137

138+
protected val emptySchema = StructType(Seq.empty).catalogString
139+
138140
protected override def sparkConf: SparkConf = super.sparkConf
139141
// Fewer shuffle partitions to speed up testing.
140142
.set(SQLConf.SHUFFLE_PARTITIONS, 4)
@@ -323,11 +325,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
323325
}
324326
// Run the SQL queries preparing them for comparison.
325327
val outputs: Seq[QueryOutput] = queries.map { sql =>
326-
val (schema, output) = getNormalizedResult(localSparkSession, sql)
328+
val (schema, output) = handleExceptions(getNormalizedResult(localSparkSession, sql))
327329
// We might need to do some query canonicalization in the future.
328330
QueryOutput(
329331
sql = sql,
330-
schema = schema.catalogString,
332+
schema = schema,
331333
output = output.mkString("\n").replaceAll("\\s+$", ""))
332334
}
333335

@@ -388,49 +390,52 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
388390
}
389391
}
390392

391-
/** Executes a query and returns the result as (schema of the output, normalized output). */
392-
private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = {
393-
// Returns true if the plan is supposed to be sorted.
394-
def isSorted(plan: LogicalPlan): Boolean = plan match {
395-
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
396-
case _: DescribeCommandBase
397-
| _: DescribeColumnCommand
398-
| _: DescribeTableStatement
399-
| _: DescribeColumnStatement => true
400-
case PhysicalOperation(_, _, Sort(_, true, _)) => true
401-
case _ => plan.children.iterator.exists(isSorted)
402-
}
403-
393+
protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
404394
try {
405-
val df = session.sql(sql)
406-
val schema = df.schema
407-
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
408-
val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
409-
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
410-
}
411-
412-
// If the output is not pre-sorted, sort it.
413-
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
414-
395+
result
415396
} catch {
416397
case a: AnalysisException =>
417398
// Do not output the logical plan tree which contains expression IDs.
418399
// Also implement a crude way of masking expression IDs in the error message
419400
// with a generic pattern "###".
420401
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
421-
(StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
402+
(emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
422403
case s: SparkException if s.getCause != null =>
423404
// For a runtime exception, it is hard to match because its message contains
424405
// information of stage, task ID, etc.
425406
// To make result matching simpler, here we match the cause of the exception if it exists.
426407
val cause = s.getCause
427-
(StructType(Seq.empty), Seq(cause.getClass.getName, cause.getMessage))
408+
(emptySchema, Seq(cause.getClass.getName, cause.getMessage))
428409
case NonFatal(e) =>
429410
// If there is an exception, put the exception class followed by the message.
430-
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
411+
(emptySchema, Seq(e.getClass.getName, e.getMessage))
431412
}
432413
}
433414

415+
/** Executes a query and returns the result as (schema of the output, normalized output). */
416+
private def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
417+
// Returns true if the plan is supposed to be sorted.
418+
def isSorted(plan: LogicalPlan): Boolean = plan match {
419+
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
420+
case _: DescribeCommandBase
421+
| _: DescribeColumnCommand
422+
| _: DescribeTableStatement
423+
| _: DescribeColumnStatement => true
424+
case PhysicalOperation(_, _, Sort(_, true, _)) => true
425+
case _ => plan.children.iterator.exists(isSorted)
426+
}
427+
428+
val df = session.sql(sql)
429+
val schema = df.schema.catalogString
430+
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
431+
val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
432+
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
433+
}
434+
435+
// If the output is not pre-sorted, sort it.
436+
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
437+
}
438+
434439
protected def replaceNotIncludedMsg(line: String): String = {
435440
line.replaceAll("#\\d+", "#x")
436441
.replaceAll(

sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
2828
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
2929

3030
import org.apache.spark.{SparkConf, SparkException}
31-
import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite}
31+
import org.apache.spark.sql.SQLQueryTestSuite
3232
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
3333
import org.apache.spark.sql.catalyst.util.fileToString
3434
import org.apache.spark.sql.execution.HiveResult
@@ -123,7 +123,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
123123

124124
// Run the SQL queries preparing them for comparison.
125125
val outputs: Seq[QueryOutput] = queries.map { sql =>
126-
val output = getNormalizedResult(statement, sql)
126+
val (_, output) = handleExceptions(getNormalizedResult(statement, sql))
127127
// We might need to do some query canonicalization in the future.
128128
QueryOutput(
129129
sql = sql,
@@ -142,8 +142,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
142142
"Try regenerate the result files.")
143143
Seq.tabulate(outputs.size) { i =>
144144
val sql = segments(i * 3 + 1).trim
145+
val schema = segments(i * 3 + 2).trim
145146
val originalOut = segments(i * 3 + 3)
146-
val output = if (isNeedSort(sql)) {
147+
val output = if (schema != emptySchema && isNeedSort(sql)) {
147148
originalOut.split("\n").sorted.mkString("\n")
148149
} else {
149150
originalOut
@@ -254,32 +255,29 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
254255
}
255256
}
256257

257-
private def getNormalizedResult(statement: Statement, sql: String): Seq[String] = {
258-
try {
259-
val rs = statement.executeQuery(sql)
260-
val cols = rs.getMetaData.getColumnCount
261-
val buildStr = () => (for (i <- 1 to cols) yield {
262-
getHiveResult(rs.getObject(i))
263-
}).mkString("\t")
264-
265-
val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
266-
.map(replaceNotIncludedMsg)
267-
if (isNeedSort(sql)) {
268-
answer.sorted
269-
} else {
270-
answer
258+
override def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
259+
super.handleExceptions {
260+
try {
261+
result
262+
} catch {
263+
case NonFatal(e) => throw ExceptionUtils.getRootCause(e)
271264
}
272-
} catch {
273-
case a: AnalysisException =>
274-
// Do not output the logical plan tree which contains expression IDs.
275-
// Also implement a crude way of masking expression IDs in the error message
276-
// with a generic pattern "###".
277-
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
278-
Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted
279-
case NonFatal(e) =>
280-
val rootCause = ExceptionUtils.getRootCause(e)
281-
// If there is an exception, put the exception class followed by the message.
282-
Seq(rootCause.getClass.getName, rootCause.getMessage)
265+
}
266+
}
267+
268+
private def getNormalizedResult(statement: Statement, sql: String): (String, Seq[String]) = {
269+
val rs = statement.executeQuery(sql)
270+
val cols = rs.getMetaData.getColumnCount
271+
val buildStr = () => (for (i <- 1 to cols) yield {
272+
getHiveResult(rs.getObject(i))
273+
}).mkString("\t")
274+
275+
val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
276+
.map(replaceNotIncludedMsg)
277+
if (isNeedSort(sql)) {
278+
("", answer.sorted)
279+
} else {
280+
("", answer)
283281
}
284282
}
285283

0 commit comments

Comments
 (0)