Skip to content

Commit b59a2df

Browse files
committed
Merge remote-tracking branch 'origin/SPARK-20754-trunc' into SPARK-20754-trunc
2 parents ea72fe0 + 931f07d commit b59a2df

File tree

3 files changed

+102
-28
lines changed

3 files changed

+102
-28
lines changed

python/pyspark/sql/functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,12 +1081,12 @@ def to_timestamp(col, format=None):
10811081

10821082

10831083
@since(1.5)
1084-
def trunc(data, truncParam):
1084+
def trunc(date, format):
10851085
"""
1086-
Returns date truncated to the unit specified by the truncParam or
1086+
Returns date truncated to the unit specified by the format or
10871087
numeric truncated by specified decimal places.
10881088
1089-
:param truncParam: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
1089+
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
10901090
and any int for numeric.
10911091
10921092
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
@@ -1103,7 +1103,7 @@ def trunc(data, truncParam):
11031103
[Row(zero=1234567891.0)]
11041104
"""
11051105
sc = SparkContext._active_spark_context
1106-
return Column(sc._jvm.functions.trunc(_to_java_column(data), truncParam))
1106+
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
11071107

11081108

11091109
@since(1.5)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -274,29 +274,69 @@ case class Trunc(data: Expression, truncExpr: Expression)
274274
if (truncFormat == -1) {
275275
ev.copy(code = s"""
276276
boolean ${ev.isNull} = true;
277-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
278-
""")
277+
int ${ev.value} = ${ctx.defaultValue(DateType)};""")
279278
} else {
280279
val d = data.genCode(ctx)
281-
ev.copy(code = s"""
280+
val dt = ctx.freshName("dt")
281+
val pre = s"""
282282
${d.code}
283283
boolean ${ev.isNull} = ${d.isNull};
284-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
285-
if (!${ev.isNull}) {
286-
${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
287-
}""")
284+
int ${ev.value} = ${ctx.defaultValue(DateType)};"""
285+
data.dataType match {
286+
case DateType =>
287+
ev.copy(code = pre + s"""
288+
if (!${ev.isNull}) {
289+
${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
290+
}""")
291+
case TimestampType =>
292+
val ts = ctx.freshName("ts")
293+
ev.copy(code = pre + s"""
294+
String $ts = $dtu.timestampToString(${d.value});
295+
scala.Option<SQLDate> $dt = $dtu.stringToDate(UTF8String.fromString($ts));
296+
if (!${ev.isNull}) {
297+
${ev.value} = $dtu.truncDate((Integer)dt.get(), $truncFormat);
298+
}""")
299+
case StringType =>
300+
ev.copy(code = pre + s"""
301+
scala.Option<SQLDate> $dt = $dtu.stringToDate(${d.value});
302+
if (!${ev.isNull} && $dt.isDefined()) {
303+
${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncFormat);
304+
}""")
305+
}
288306
}
289307
} else {
290308
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
291309
val truncParam = ctx.freshName("truncParam")
292-
s"""
293-
int $truncParam = $dtu.parseTruncLevel($fmt);
294-
if ($truncParam == -1) {
295-
${ev.isNull} = true;
296-
} else {
297-
${ev.value} = $dtu.truncDate($dateVal, $truncParam);
298-
}
299-
"""
310+
val dt = ctx.freshName("dt")
311+
val pre = s"int $truncParam = $dtu.parseTruncLevel($fmt);"
312+
data.dataType match {
313+
case DateType =>
314+
pre + s"""
315+
if ($truncParam == -1) {
316+
${ev.isNull} = true;
317+
} else {
318+
${ev.value} = $dtu.truncDate($dateVal, $truncParam);
319+
}"""
320+
case TimestampType =>
321+
val ts = ctx.freshName("ts")
322+
pre + s"""
323+
String $ts = $dtu.timestampToString($dateVal);
324+
scala.Option<SQLDate> $dt = $dtu.stringToDate(UTF8String.fromString($ts));
325+
if ($truncParam == -1 || $dt.isEmpty()) {
326+
${ev.isNull} = true;
327+
} else {
328+
${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam);
329+
}"""
330+
case StringType =>
331+
pre + s"""
332+
scala.Option<SQLDate> $dt = $dtu.stringToDate($dateVal);
333+
${ev.value} = ${ctx.defaultValue(DateType)};
334+
if ($truncParam == -1 || $dt.isEmpty()) {
335+
${ev.isNull} = true;
336+
} else {
337+
${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam);
338+
}"""
339+
}
300340
})
301341
}
302342
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.sql.Date
20+
import java.sql.{Date, Timestamp}
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.types._
@@ -74,23 +74,57 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
7474
}
7575

7676
test("trunc date") {
77-
def test(input: Date, fmt: String, expected: Date): Unit = {
77+
def testDate(input: Date, fmt: String, expected: Date): Unit = {
7878
checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)),
7979
expected)
8080
checkEvaluation(
8181
Trunc(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
8282
expected)
8383
}
84-
val date = Date.valueOf("2015-07-22")
84+
85+
def testString(input: String, fmt: String, expected: Date): Unit = {
86+
checkEvaluation(Trunc(Literal.create(input, StringType), Literal.create(fmt, StringType)),
87+
expected)
88+
checkEvaluation(
89+
Trunc(Literal.create(input, StringType), NonFoldableLiteral.create(fmt, StringType)),
90+
expected)
91+
}
92+
93+
def testTimestamp(input: Timestamp, fmt: String, expected: Date): Unit = {
94+
checkEvaluation(Trunc(Literal.create(input, TimestampType), Literal.create(fmt, StringType)),
95+
expected)
96+
checkEvaluation(
97+
Trunc(Literal.create(input, TimestampType), NonFoldableLiteral.create(fmt, StringType)),
98+
expected)
99+
}
100+
101+
val dateStr = "2015-07-22"
102+
val date = Date.valueOf(dateStr)
103+
val ts = new Timestamp(date.getTime)
104+
85105
Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
86-
test(date, fmt, Date.valueOf("2015-01-01"))
106+
testDate(date, fmt, Date.valueOf("2015-01-01"))
107+
testString(dateStr, fmt, Date.valueOf("2015-01-01"))
108+
testTimestamp(ts, fmt, Date.valueOf("2015-01-01"))
87109
}
88110
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
89-
test(date, fmt, Date.valueOf("2015-07-01"))
111+
testDate(date, fmt, Date.valueOf("2015-07-01"))
112+
testString(dateStr, fmt, Date.valueOf("2015-07-01"))
113+
testTimestamp(ts, fmt, Date.valueOf("2015-07-01"))
90114
}
91-
test(date, "DD", null)
92-
test(date, null, null)
93-
test(null, "MON", null)
94-
test(null, null, null)
115+
testDate(date, "DD", null)
116+
testDate(date, null, null)
117+
testDate(null, "MON", null)
118+
testDate(null, null, null)
119+
120+
testString(dateStr, "DD", null)
121+
testString(dateStr, null, null)
122+
testString(null, "MON", null)
123+
testString(null, null, null)
124+
125+
testTimestamp(ts, "DD", null)
126+
testTimestamp(ts, null, null)
127+
testTimestamp(null, "MON", null)
128+
testTimestamp(null, null, null)
95129
}
96130
}

0 commit comments

Comments
 (0)