
Commit 6bc6cc4

Author: Davies Liu (committed)

infer int as LongType

1 parent 9b746f3

File tree

5 files changed: +35 −11 lines

  python/pyspark/sql/dataframe.py
  python/pyspark/sql/tests.py
  python/pyspark/sql/types.py
  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
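The user-visible change: PySpark's schema inference previously mapped Python int to the 32-bit IntegerType, which overflowed on larger values; it now maps to the 64-bit LongType. That is also why the doctest outputs below gain an L suffix on aggregate column names (age#4L instead of age#4). A minimal sketch of the new behavior, assuming a live SparkContext sc and SQLContext as in the doctest fixtures:

from pyspark.sql.types import LongType, Row

# 'age' holds a plain Python int; its inferred type is now LongType
# rather than the old 32-bit IntegerType.
df = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
assert df.schema.fields[0].dataType == LongType()  # Row kwargs sort as (age, name)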


python/pyspark/sql/dataframe.py

Lines changed: 8 additions & 6 deletions
@@ -780,7 +780,7 @@ def mean(self, *cols):
         >>> df.groupBy().mean('age').collect()
         [Row(AVG(age#0)=3.5)]
         >>> df3.groupBy().mean('age', 'height').collect()
-        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
         """
 
     @df_varargs_api
@@ -791,7 +791,7 @@ def avg(self, *cols):
         >>> df.groupBy().avg('age').collect()
         [Row(AVG(age#0)=3.5)]
         >>> df3.groupBy().avg('age', 'height').collect()
-        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
         """
 
     @df_varargs_api
@@ -802,7 +802,7 @@ def max(self, *cols):
         >>> df.groupBy().max('age').collect()
         [Row(MAX(age#0)=5)]
         >>> df3.groupBy().max('age', 'height').collect()
-        [Row(MAX(age#4)=5, MAX(height#5)=85)]
+        [Row(MAX(age#4L)=5, MAX(height#5L)=85)]
         """
 
     @df_varargs_api
@@ -813,7 +813,7 @@ def min(self, *cols):
         >>> df.groupBy().min('age').collect()
         [Row(MIN(age#0)=2)]
         >>> df3.groupBy().min('age', 'height').collect()
-        [Row(MIN(age#4)=2, MIN(height#5)=80)]
+        [Row(MIN(age#4L)=2, MIN(height#5L)=80)]
         """
 
     @df_varargs_api
@@ -824,7 +824,7 @@ def sum(self, *cols):
         >>> df.groupBy().sum('age').collect()
         [Row(SUM(age#0)=7)]
         >>> df3.groupBy().sum('age', 'height').collect()
-        [Row(SUM(age#4)=7, SUM(height#5)=165)]
+        [Row(SUM(age#4L)=7, SUM(height#5L)=165)]
         """
 
 
@@ -1028,7 +1028,9 @@ def _test():
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
-    globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
+        .toDF(StructType([StructField('age', IntegerType()),
+                          StructField('name', StringType())]))
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
     globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                    Row(name='Bob', age=5, height=85)]).toDF()
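Note the fixture change in _test() above: since plain ints now infer as LongType, the doctest df is pinned to an explicit IntegerType schema so its column names keep the old age#0 form. The same pattern is the general opt-out whenever a 32-bit column is wanted; a sketch under the same sc fixture:

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# An explicit schema overrides inference: 'age' stays IntegerType instead
# of the LongType that inference would now produce.
schema = StructType([StructField('age', IntegerType()),
                     StructField('name', StringType())])
df = sc.parallelize([(2, 'Alice'), (5, 'Bob')]).toDF(schema)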

python/pyspark/sql/tests.py

Lines changed: 21 additions & 1 deletion
@@ -36,7 +36,7 @@
 
 from pyspark.sql import SQLContext, HiveContext, Column
 from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType, LongType, StringType
+    UserDefinedType, DoubleType, LongType, StringType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
 
 
@@ -322,6 +322,26 @@ def test_help_command(self):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+    def test_infer_long_type(self):
+        longrow = [Row(f1='a', f2=100000000000000)]
+        df = self.sc.parallelize(longrow).toDF()
+        self.assertEqual(df.schema.fields[1].dataType, LongType())
+
+        # this saving as Parquet caused issues as well.
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        df.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        self.assertEquals('a', df1.first().f1)
+        self.assertEquals(100000000000000, df1.first().f2)
+
+        self.assertEqual(_infer_type(1), LongType())
+        self.assertEqual(_infer_type(2**10), LongType())
+        self.assertEqual(_infer_type(2**20), LongType())
+        self.assertEqual(_infer_type(2**31 - 1), LongType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
 
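One edge the new test pins down: _infer_type returns LongType for every Python int, including 2**71, even though LongType is a signed 64-bit type. Inference does not range-check; it only picks the column type. The arithmetic behind the chosen test values:

# LongType is a signed 64-bit integer.
LONG_MIN, LONG_MAX = -2**63, 2**63 - 1

assert 2**31 - 1 == 2147483647    # the old IntegerType ceiling
assert 2**61 <= LONG_MAX          # comfortably fits in a long column
assert 2**71 > LONG_MAX           # exceeds the range, yet still infers as LongType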

python/pyspark/sql/types.py

Lines changed: 4 additions & 4 deletions
@@ -583,7 +583,7 @@ def _parse_datatype_json_value(json_value):
 _type_mappings = {
     type(None): NullType,
     bool: BooleanType,
-    int: IntegerType,
+    int: LongType,
     long: LongType,
     float: DoubleType,
     str: StringType,
@@ -933,11 +933,11 @@ def _infer_schema_type(obj, dataType):
     >>> schema = _parse_schema_abstract("a b c d")
     >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
     >>> _infer_schema_type(row, schema)
-    StructType...IntegerType...DoubleType...StringType...DateType...
+    StructType...LongType...DoubleType...StringType...DateType...
     >>> row = [[1], {"key": (1, 2.0)}]
     >>> schema = _parse_schema_abstract("a[] b{c d}")
     >>> _infer_schema_type(row, schema)
-    StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
     """
     if dataType is None:
         return _infer_type(obj)
@@ -992,7 +992,7 @@ def _verify_type(obj, dataType):
 
     >>> _verify_type(None, StructType([]))
     >>> _verify_type("", StringType())
-    >>> _verify_type(0, IntegerType())
+    >>> _verify_type(0, LongType())
     >>> _verify_type(range(3), ArrayType(ShortType()))
     >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
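The _type_mappings edit is the heart of the commit: _infer_type resolves an object's exact Python type through that dict before falling back to structural inference for dicts, lists, and Rows, so this single entry retypes every inferred int field. A simplified model of the lookup, using string stand-ins for the type classes rather than the real implementation:

# Simplified model of the dict lookup at the top of _infer_type.
_type_mappings_sketch = {
    type(None): 'NullType',
    bool: 'BooleanType',
    int: 'LongType',   # was 'IntegerType' before this commit
    float: 'DoubleType',
    str: 'StringType',
}

def infer_type_sketch(obj):
    data_type = _type_mappings_sketch.get(type(obj))
    if data_type is None:
        raise TypeError("not supported type: %s" % type(obj))
    return data_type

assert infer_type_sketch(2**40) == 'LongType'  # would overflow a 32-bit int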

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 0 deletions
@@ -1129,6 +1129,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     def needsConversion(dataType: DataType): Boolean = dataType match {
       case ByteType => true
       case ShortType => true
+      case LongType => true
       case FloatType => true
       case DateType => true
       case TimestampType => true
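needsConversion tells applySchemaToPythonRDD which column types need a conversion pass over data arriving from Python; LongType now qualifies because a pickled Python int can surface on the JVM as either an Integer or a Long depending on its magnitude. A round-trip sketch of what this keeps working, using the same 1.3-era APIs as the diffs above:

from pyspark.sql.types import StructType, StructField, LongType

# 1 arrives on the JVM as an Integer, 2**40 as a Long; both belong to the
# same LongType column, hence the conversion pass.
schema = StructType([StructField('n', LongType())])
df = sc.parallelize([(1,), (2**40,)]).toDF(schema)
assert sorted(r.n for r in df.collect()) == [1, 2**40]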

sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ object EvaluatePython {
     case (c: Int, ShortType) => c.toShort
     case (c: Long, ShortType) => c.toShort
     case (c: Long, IntegerType) => c.toInt
+    case (c: Int, LongType) => c.toLong
     case (c: Double, FloatType) => c.toFloat
     case (c, StringType) if !c.isInstanceOf[String] => c.toString
 
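This is the matching case in EvaluatePython: a value that arrived as a JVM Int is widened with toLong when the declared column type is LongType, alongside the existing narrowing cases. The reason Ints show up at all is the unpickling boundary at 2**31 - 1; a sketch of that split (the 'Integer'/'Long' labels model the unpickler's behavior and are an assumption, not Spark API):

# Ints within 32-bit range unpickle as JVM Integer, larger ones as Long,
# so a LongType column can receive both representations.
INT_MAX = 2**31 - 1

def jvm_type_of(v):
    return 'Integer' if -2**31 <= v <= INT_MAX else 'Long'

assert jvm_type_of(1) == 'Integer'         # handled by the new Int -> Long case
assert jvm_type_of(INT_MAX + 1) == 'Long'  # already handled before this commit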
