 import datetime
 import array
 import math
+from pyspark.sql import types
 
 import py4j
 try:
@@ -2277,17 +2278,18 @@ def collected(a):
             return df.collect()[0]["myarray"][0]
 
         # test whether pyspark can correctly handle string types
-        string_types = set()
+        string_types = []
         if sys.version < "4":
-            string_types += set(['u'])
+            string_types += ['u']
             self.assertEqual(collected(array.array('u', ["a"])), "a")
         if sys.version < "3":
-            string_types += set(['c'])
+            string_types += ['c']
             self.assertEqual(collected(array.array('c', ["a"])), "a")
 
         # test whether pyspark can correctly handle int types
-        int_types = set(['b', 'h', 'i', 'l'])
-        for t in int_types:
+        int_types = ['b', 'h', 'i', 'l']
+        unsigned_types = ['B', 'H', 'I']
+        for t in int_types + unsigned_types:
             # Start from 1 and keep doubling the number until overflow.
             a = array.array(t, [1])
             while True:
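For context, the loop this hunk rewrites probes each integer typecode's range by doubling an element in place until the array itself rejects the value. A minimal standalone sketch of the same technique (the `max_value` helper is illustrative only, not part of the patch):

```python
import array

def max_value(typecode):
    """Largest power of two an array of `typecode` can store: keep
    doubling until assignment into the C slot raises OverflowError."""
    a = array.array(typecode, [1])
    while True:
        try:
            a[0] *= 2
        except OverflowError:
            return a[0]

# On a platform where 'b' is a signed 8-bit int, max_value('b') == 64.
```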
@@ -2296,6 +2298,7 @@ def collected(a):
                     a[0] *= 2
                 except OverflowError:
                     break
+        for t in int_types:
             # Start from -1 and keep doubling the number until overflow
             a = array.array(t, [-1])
             while True:
@@ -2306,7 +2309,7 @@ def collected(a):
                     break
 
         # test whether pyspark can correctly handle float types
-        float_types = set(['f', 'd'])
+        float_types = ['f', 'd']
         for t in float_types:
             # test upper bound and precision
             a = array.array(t, [1.0])
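The float branch applies the same probing idea in both directions: doubling toward the type's upper bound and halving toward zero (the `a[0] /= 2` line in the next hunk). A rough standalone sketch, assuming IEEE 754 behaviour where an overflowing store into an 'f'/'d' slot yields infinity; `probe_range` is illustrative only:

```python
import array
import math

def probe_range(typecode):
    """Return (largest finite, smallest positive) values reachable by
    doubling/halving 1.0 stored in a float array of this typecode."""
    a = array.array(typecode, [1.0])
    while True:
        prev = a[0]
        a[0] *= 2                # store back into the C float/double slot
        if math.isinf(a[0]):     # the slot overflowed to infinity
            largest = prev
            break
    b = array.array(typecode, [1.0])
    while b[0] > 0.0:
        smallest = b[0]
        b[0] /= 2                # eventually underflows the slot to 0.0
    return largest, smallest
```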
@@ -2321,14 +2324,13 @@ def collected(a):
                 a[0] /= 2
 
         # make sure that the test cases cover all supported types
-        supported_types = int_types + float_types + string_types
-        self.assertEqual(supported_types, _array_type_mappings.keys)
+        supported_types = set(int_types + unsigned_types + float_types + string_types)
+        self.assertEqual(supported_types, set(types._array_type_mappings.keys()))
 
-        all_type_codes = set()
         if sys.version < "3":
-            all_type_codes += set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
+            all_type_codes = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
         else:
-            all_type_codes += set(array.typecodes)
+            all_type_codes = set(array.typecodes)
         unsupported_types = all_type_codes - supported_types
 
         # test whether pyspark can correctly handle unsupported types
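For reference, the behaviour these assertions exercise end to end: `createDataFrame` maps an `array.array` column through `_array_type_mappings` to a Spark array type for supported typecodes. A minimal usage sketch of the supported path, assuming an active `SparkSession` named `spark`:

```python
import array
from pyspark.sql import Row

# Supported typecode: 'i' round-trips through createDataFrame/collect.
row = Row(myarray=array.array('i', [1, 2, 3]))
df = spark.createDataFrame([row])  # assumes an active SparkSession `spark`
assert df.collect()[0]["myarray"][0] == 1
```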