Skip to content

Commit d4fadbb

Browse files
committed
Add a test case for AggregatedDialect.isCascadingTruncateTable.
1 parent 803b196 commit d4fadbb

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
4646
// If any dialect claims cascading truncate, this dialect is also cascading truncate.
4747
// Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown.
4848
val cascading = dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _)
49-
if (cascading.get) {
49+
if (cascading.getOrElse(false)) {
5050
cascading
5151
} else {
5252
if (dialects.exists(_.isCascadingTruncateTable().isEmpty)) {

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -747,19 +747,34 @@ class JDBCSuite extends SparkFunSuite
747747
assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
748748
assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
749749
assert(agg.isCascadingTruncateTable() === Some(true))
750+
}
750751

751-
val agg2 = new AggregatedDialect(List(new JdbcDialect {
752-
override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
752+
test("Aggregated dialects: isCascadingTruncateTable") {
753+
def genDialect(cascadingTruncateTable: Option[Boolean]): JdbcDialect = new JdbcDialect {
754+
override def canHandle(url: String): Boolean = true
753755
override def getCatalystType(
754-
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
755-
if (sqlType % 2 == 0) {
756-
Some(LongType)
757-
} else {
758-
None
759-
}
760-
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
761-
}, testH2Dialect))
762-
assert(agg2.isCascadingTruncateTable() === None)
756+
sqlType: Int,
757+
typeName: String,
758+
size: Int,
759+
md: MetadataBuilder): Option[DataType] = None
760+
override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable
761+
}
762+
763+
val dialectCombination = Seq(
764+
List(genDialect(Some(true)), genDialect(Some(false)), genDialect(None)),
765+
List(genDialect(Some(true)), genDialect(Some(true)), genDialect(None)),
766+
List(genDialect(Some(false)), genDialect(Some(false)), genDialect(None)),
767+
List(genDialect(Some(true)), genDialect(Some(true))),
768+
List(genDialect(Some(false)), genDialect(Some(false))),
769+
List(genDialect(None), genDialect(None))
770+
)
771+
772+
val expectedCascading = Seq(Some(true), Some(true), None, Some(true), Some(false), None)
773+
774+
dialectCombination.zip(expectedCascading).foreach { case (dialects, cascading) =>
775+
val agg = new AggregatedDialect(dialects)
776+
assert(agg.isCascadingTruncateTable() === cascading)
777+
}
763778
}
764779

765780
test("DB2Dialect type mapping") {

0 commit comments

Comments
 (0)