@@ -135,6 +135,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
135135 private val notIncludedMsg = " [not included in comparison]"
136136 private val clsName = this .getClass.getCanonicalName
137137
138+ protected val emptySchema = StructType (Seq .empty).catalogString
139+
138140 protected override def sparkConf : SparkConf = super .sparkConf
139141 // Fewer shuffle partitions to speed up testing.
140142 .set(SQLConf .SHUFFLE_PARTITIONS , 4 )
@@ -323,11 +325,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
323325 }
324326 // Run the SQL queries preparing them for comparison.
325327 val outputs : Seq [QueryOutput ] = queries.map { sql =>
326- val (schema, output) = getNormalizedResult(localSparkSession, sql)
328+ val (schema, output) = handleExceptions( getNormalizedResult(localSparkSession, sql) )
327329 // We might need to do some query canonicalization in the future.
328330 QueryOutput (
329331 sql = sql,
330- schema = schema.catalogString ,
332+ schema = schema,
331333 output = output.mkString(" \n " ).replaceAll(" \\ s+$" , " " ))
332334 }
333335
@@ -388,49 +390,52 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
388390 }
389391 }
390392
391- /** Executes a query and returns the result as (schema of the output, normalized output). */
392- private def getNormalizedResult (session : SparkSession , sql : String ): (StructType , Seq [String ]) = {
393- // Returns true if the plan is supposed to be sorted.
394- def isSorted (plan : LogicalPlan ): Boolean = plan match {
395- case _ : Join | _ : Aggregate | _ : Generate | _ : Sample | _ : Distinct => false
396- case _ : DescribeCommandBase
397- | _ : DescribeColumnCommand
398- | _ : DescribeTableStatement
399- | _ : DescribeColumnStatement => true
400- case PhysicalOperation (_, _, Sort (_, true , _)) => true
401- case _ => plan.children.iterator.exists(isSorted)
402- }
403-
393+ protected def handleExceptions (result : => (String , Seq [String ])): (String , Seq [String ]) = {
404394 try {
405- val df = session.sql(sql)
406- val schema = df.schema
407- // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
408- val answer = SQLExecution .withNewExecutionId(session, df.queryExecution, Some (sql)) {
409- hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
410- }
411-
412- // If the output is not pre-sorted, sort it.
413- if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
414-
395+ result
415396 } catch {
416397 case a : AnalysisException =>
417398 // Do not output the logical plan tree which contains expression IDs.
418399 // Also implement a crude way of masking expression IDs in the error message
419400 // with a generic pattern "###".
420401 val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
421- (StructType ( Seq .empty) , Seq (a.getClass.getName, msg.replaceAll(" #\\ d+" , " #x" )))
402+ (emptySchema , Seq (a.getClass.getName, msg.replaceAll(" #\\ d+" , " #x" )))
422403 case s : SparkException if s.getCause != null =>
423404 // For a runtime exception, it is hard to match because its message contains
424405 // information of stage, task ID, etc.
425406 // To make result matching simpler, here we match the cause of the exception if it exists.
426407 val cause = s.getCause
427- (StructType ( Seq .empty) , Seq (cause.getClass.getName, cause.getMessage))
408+ (emptySchema , Seq (cause.getClass.getName, cause.getMessage))
428409 case NonFatal (e) =>
429410 // If there is an exception, put the exception class followed by the message.
430- (StructType ( Seq .empty) , Seq (e.getClass.getName, e.getMessage))
411+ (emptySchema , Seq (e.getClass.getName, e.getMessage))
431412 }
432413 }
433414
415+ /** Executes a query and returns the result as (schema of the output, normalized output). */
416+ private def getNormalizedResult (session : SparkSession , sql : String ): (String , Seq [String ]) = {
417+ // Returns true if the plan is supposed to be sorted.
418+ def isSorted (plan : LogicalPlan ): Boolean = plan match {
419+ case _ : Join | _ : Aggregate | _ : Generate | _ : Sample | _ : Distinct => false
420+ case _ : DescribeCommandBase
421+ | _ : DescribeColumnCommand
422+ | _ : DescribeTableStatement
423+ | _ : DescribeColumnStatement => true
424+ case PhysicalOperation (_, _, Sort (_, true , _)) => true
425+ case _ => plan.children.iterator.exists(isSorted)
426+ }
427+
428+ val df = session.sql(sql)
429+ val schema = df.schema.catalogString
430+ // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
431+ val answer = SQLExecution .withNewExecutionId(session, df.queryExecution, Some (sql)) {
432+ hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
433+ }
434+
435+ // If the output is not pre-sorted, sort it.
436+ if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
437+ }
438+
434439 protected def replaceNotIncludedMsg (line : String ): String = {
435440 line.replaceAll(" #\\ d+" , " #x" )
436441 .replaceAll(
0 commit comments