Skip to content

Commit dd6405c

Browse files
committed
Fix Dataset.dropDuplicates.
1 parent 7e16c94 commit dd6405c

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,17 +1878,23 @@ class Dataset[T] private[sql](
18781878
def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
18791879
val resolver = sparkSession.sessionState.analyzer.resolver
18801880
val allColumns = queryExecution.analyzed.output
1881-
val groupCols = colNames.map { colName =>
1882-
allColumns.find(col => resolver(col.name, colName)).getOrElse(
1881+
val groupCols = colNames.flatMap { colName =>
1882+
// It is possible that there is more than one column with the same name,
1883+
// so we call filter instead of find.
1884+
val cols = allColumns.filter(col => resolver(col.name, colName))
1885+
if (cols.isEmpty) {
18831886
throw new AnalysisException(
1884-
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
1887+
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
1888+
}
1889+
cols
18851890
}
18861891
val groupColExprIds = groupCols.map(_.exprId)
18871892
val aggCols = logicalPlan.output.map { attr =>
18881893
if (groupColExprIds.contains(attr.exprId)) {
18891894
attr
18901895
} else {
1891-
Alias(new First(attr).toAggregateExpression(), attr.name)()
1896+
// We should keep the original exprId of the attribute.
1897+
Alias(new First(attr).toAggregateExpression(), attr.name)(exprId = attr.exprId)
18921898
}
18931899
}
18941900
Aggregate(groupCols, aggCols, logicalPlan)

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
872872
("a", 1), ("a", 2), ("b", 1))
873873
}
874874

875+
test("dropDuplicates: columns with same column name") {
876+
val ds1 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
877+
val ds2 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
878+
// The dataset joined has two columns of the same name "_2".
879+
val joined = ds1.join(ds2, "_1").select(ds1("_2").as[Int], ds2("_2").as[Int])
880+
checkDataset(
881+
joined.dropDuplicates(),
882+
(1, 2), (1, 1), (2, 1), (2, 2))
883+
}
884+
885+
test("dropDuplicates should not change child plan output") {
886+
val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
887+
checkDataset(
888+
ds.dropDuplicates("_1").select(ds("_1").as[String], ds("_2").as[Int]),
889+
("a", 1), ("b", 1))
890+
}
891+
875892
test("SPARK-16097: Encoders.tuple should handle null object correctly") {
876893
val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
877894
val data = Seq((("a", "b"), "c"), (null, "d"))

0 commit comments

Comments (0)