@@ -2270,17 +2270,23 @@ def test_array_types(self):
22702270 # and Scala types.
22712271 # See: https://docs.python.org/2/library/array.html
22722272
2273- int_types = set (['b' , 'h' , 'i' , 'l' ])
2274- float_types = set (['f' , 'd' ])
2275- unsupported_types = set (array .typecodes ) - int_types - float_types
2276-
# Helper: round-trips `a` through Spark and returns the first element that
# comes back. The value is wrapped in a single-field Row ("myarray"), placed
# in a one-row RDD, converted to a DataFrame, and collected again.
# The visible callers pass a one-element array.array.
22772273 def collected (a ):
22782274 row = Row (myarray = a )
22792275 rdd = self .sc .parallelize ([row ])
22802276 df = self .spark .createDataFrame (rdd )
# collect()[0] is the only row; ["myarray"][0] is the array's first element.
22812277 return df .collect ()[0 ]["myarray" ][0 ]
22822278
2279+ # test whether pyspark can correctly handle string types
2280+ string_types = set ()
2281+ if sys .version < "4" :
2282+ string_types += set (['u' ])
2283+ self .assertEqual (collected (array .array ('u' , ["a" ])), "a" )
2284+ if sys .version < "3" :
2285+ string_types += set (['c' ])
2286+ self .assertEqual (collected (array .array ('c' , ["a" ])), "a" )
2287+
22832288 # test whether pyspark can correctly handle int types
2289+ int_types = set (['b' , 'h' , 'i' , 'l' ])
22842290 for t in int_types :
22852291 # Start from 1 and keep doubling the number until overflow.
22862292 a = array .array (t , [1 ])
@@ -2300,6 +2306,7 @@ def collected(a):
23002306 break
23012307
23022308 # test whether pyspark can correctly handle float types
2309+ float_types = set (['f' , 'd' ])
23032310 for t in float_types :
23042311 # test upper bound and precision
23052312 a = array .array (t , [1.0 ])
@@ -2312,6 +2319,18 @@ def collected(a):
23122319 while a [0 ] != 0 :
23132320 self .assertEqual (collected (a ), a [0 ])
23142321 a [0 ] /= 2
2322+
2323+ # make sure that the test case covers all supported types
2324+ supported_types = int_types + float_types + string_types
2325+ self .assertEqual (supported_types , _array_type_mappings .keys )
2326+
2327+ all_type_codes = set ()
2328+ if sys .version < "3" :
2329+ all_type_codes += set ([ 'c' ,'b' ,'B' ,'u' ,'h' ,'H' ,'i' ,'I' ,'l' ,'L' ,'f' ,'d' ])
2330+ else :
2331+ all_type_codes += set (array .typecodes )
2332+ unsupported_types = all_type_codes - supported_types
2333+
23152334 # test whether pyspark can correctly handle unsupported types
23162335 for t in unsupported_types :
23172336 try :
0 commit comments