@@ -42,6 +42,9 @@ class DataType(object):
4242 """Spark SQL DataType"""
4343
    def __str__(self):
        # str() and repr() of a DataType are kept identical by delegating here.
        return repr(self)

    def __repr__(self):
        # Default representation is just the class name (e.g. "IntegerType");
        # parameterized types (MapType, StructField, StructType) override this.
        return self.__class__.__name__
4649
4750 def __hash__ (self ):
@@ -210,7 +213,7 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
210213 self .valueType = valueType
211214 self .valueContainsNull = valueContainsNull
212215
213- def __str__ (self ):
216+ def __repr__ (self ):
214217 return "MapType(%s,%s,%s)" % (self .keyType , self .valueType ,
215218 str (self .valueContainsNull ).lower ())
216219
@@ -240,7 +243,7 @@ def __init__(self, name, dataType, nullable):
240243 self .dataType = dataType
241244 self .nullable = nullable
242245
243- def __str__ (self ):
246+ def __repr__ (self ):
244247 return "StructField(%s,%s,%s)" % (self .name , self .dataType ,
245248 str (self .nullable ).lower ())
246249
@@ -267,7 +270,7 @@ def __init__(self, fields):
267270 """
268271 self .fields = fields
269272
270- def __str__ (self ):
273+ def __repr__ (self ):
271274 return ("StructType(List(%s))" %
272275 "," .join (str (field ) for field in self .fields ))
273276
@@ -401,7 +404,7 @@ def _parse_datatype_string(datatype_string):
401404 datetime .time : TimestampType ,
402405}
403406
404- def _inferType (obj ):
407+ def _infer_type (obj ):
405408 """Infer the DataType from obj"""
406409 if obj is None :
407410 raise ValueError ("Can not infer type for None" )
@@ -414,18 +417,18 @@ def _inferType(obj):
414417 if not obj :
415418 raise ValueError ("Can not infer type for empty dict" )
416419 key , value = obj .iteritems ().next ()
417- return MapType (_inferType (key ), _inferType (value ), True )
420+ return MapType (_infer_type (key ), _infer_type (value ), True )
418421 elif isinstance (obj , (list , array .array )):
419422 if not obj :
420423 raise ValueError ("Can not infer type for empty list/array" )
421- return ArrayType (_inferType (obj [0 ]), True )
424+ return ArrayType (_infer_type (obj [0 ]), True )
422425 else :
423426 try :
424- return _inferSchema (obj )
427+ return _infer_schema (obj )
425428 except ValueError :
426429 raise ValueError ("not supported type: %s" % type (obj ))
427430
428- def _inferSchema (row ):
431+ def _infer_schema (row ):
429432 """Infer the schema from dict/namedtuple/object"""
430433 if isinstance (row , dict ):
431434 items = sorted (row .items ())
@@ -440,7 +443,7 @@ def _inferSchema(row):
440443 else :
441444 raise ValueError ("Can not infer schema for type: %s" % type (row ))
442445
443- fields = [StructField (k , _inferType (v ), True ) for k , v in items ]
446+ fields = [StructField (k , _infer_type (v ), True ) for k , v in items ]
444447 return StructType (fields )
445448
446449def _create_converter (obj , dataType ):
@@ -494,6 +497,133 @@ def _dropSchema(rows, schema):
494497 yield converter (i )
495498
496499
500+ _BRAKETS = {'(' :')' , '[' :']' , '{' :'}' }
501+
502+ def _split_schema_abstract (s ):
503+ """
504+ split the schema abstract into fields
505+
506+ >>> _split_schema_abstract("a b c")
507+ ['a', 'b', 'c']
508+ >>> _split_schema_abstract("a(a b)")
509+ ['a(a b)']
510+ >>> _split_schema_abstract("a b[] c{a b}")
511+ ['a', 'b[]', 'c{a b}']
512+ >>> _split_schema_abstract(" ")
513+ []
514+ """
515+
516+ r = []
517+ w = ''
518+ brackets = []
519+ for c in s :
520+ if c == ' ' and not brackets :
521+ if w :
522+ r .append (w )
523+ w = ''
524+ else :
525+ w += c
526+ if c in _BRAKETS :
527+ brackets .append (c )
528+ elif c in _BRAKETS .values ():
529+ if not brackets or c != _BRAKETS [brackets .pop ()]:
530+ raise ValueError ("unexpected " + c )
531+
532+ if brackets :
533+ raise ValueError ("brackets not closed: %s" % brackets )
534+ if w :
535+ r .append (w )
536+ return r
537+
def _parse_field_abstract(s):
    """
    Parse a field in schema abstract

    >>> _parse_field_abstract("a")
    StructField(a,None,true)
    >>> _parse_field_abstract("b(c d)")
    StructField(b,StructType(List(StructField(c,None,true),StructField(d,None,true))),true)
    >>> _parse_field_abstract("a[]")
    StructField(a,ArrayType(None,true),true)
    >>> _parse_field_abstract("a{[]}")
    StructField(a,MapType(None,ArrayType(None,true),true),true)
    """
    # locate the first opening bracket, if any: everything before it is
    # the field name, everything from it onward is a nested type spec
    positions = [s.index(c) for c in _BRAKETS if c in s]
    if positions:
        split_at = min(positions)
        nested = _parse_schema_abstract(s[split_at:])
        return StructField(s[:split_at], nested, True)
    # no brackets: just a name with the type left to be inferred (None)
    return StructField(s, None, True)
557+
def _parse_schema_abstract(s):
    """
    parse abstract into schema

    >>> _parse_schema_abstract("a b c")
    StructType...a...b...c...
    >>> _parse_schema_abstract("a[b c] b{}")
    StructType...a,ArrayType...b...c...b,MapType...
    >>> _parse_schema_abstract("c{} d{a b}")
    StructType...c,MapType...d,MapType...a...b...
    >>> _parse_schema_abstract("a b(t)").fields[1]
    StructField(b,StructType(List(StructField(t,None,true))),true)
    """
    s = s.strip()
    if not s:
        # an empty spec means "infer this type from the data later"
        return None

    head = s[0]
    if head == '(':
        # a parenthesized group is a nested struct
        return _parse_schema_abstract(s[1:-1])
    if head == '[':
        return ArrayType(_parse_schema_abstract(s[1:-1]), True)
    if head == '{':
        # map keys are always inferred from the data, so keyType is None
        return MapType(None, _parse_schema_abstract(s[1:-1]))

    # otherwise this is a space-separated list of fields
    return StructType([_parse_field_abstract(part)
                       for part in _split_schema_abstract(s)])
587+
def _infer_schema_type(obj, dataType):
    """
    Fill the dataType with types inferred from obj

    >>> schema = _parse_schema_abstract("a b c")
    >>> row = (1, 1.0, "str")
    >>> _infer_schema_type(row, schema)
    StructType...IntegerType...DoubleType...StringType...
    >>> row = [[1], {"key": (1, 2.0)}]
    >>> schema = _parse_schema_abstract("a[] b{c d}")
    >>> _infer_schema_type(row, schema)
    StructType...a,ArrayType...b,MapType(StringType,StructType...c,IntegerType...
    """
    # an abstract schema leaves unknown types as None: fall back to plain
    # inference from the value itself
    if dataType is None:
        return _infer_type(obj)

    if not obj:
        raise ValueError("Can not infer type from empty value")

    if isinstance(dataType, ArrayType):
        # the element type is inferred from the first element only
        element = _infer_schema_type(obj[0], dataType.elementType)
        return ArrayType(element, True)

    if isinstance(dataType, MapType):
        # peek at a single (key, value) pair to fill in both types
        # (Python 2 idiom, consistent with the rest of this file)
        k, v = obj.iteritems().next()
        return MapType(_infer_type(k),
                       _infer_schema_type(v, dataType.valueType))

    if isinstance(dataType, StructType):
        fs = dataType.fields
        assert len(fs) == len(obj), \
            "Obj(%s) have different length with fields(%s)" % (obj, fs)
        return StructType([StructField(f.name,
                                       _infer_schema_type(o, f.dataType),
                                       True)
                           for o, f in zip(obj, fs)])

    raise ValueError("Unexpected dataType: %s" % dataType)
625+
626+
497627_cached_cls = {}
498628
499629def _restore_object (fields , obj ):
@@ -684,7 +814,7 @@ def inferSchema(self, rdd):
684814 if not first :
685815 raise ValueError ("The first row in RDD is empty, can not infer schema" )
686816
687- schema = _inferSchema (first )
817+ schema = _infer_schema (first )
688818 rdd = rdd .mapPartitions (lambda rows : _dropSchema (rows , schema ))
689819 return self .applySchema (rdd , schema )
690820
@@ -698,6 +828,7 @@ def applySchema(self, rdd, schema):
698828 >>> srdd2 = sqlCtx.sql("SELECT * from table1")
699829 >>> srdd2.collect()
700830 [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2'), Row(field1=3, field2=u'row3')]
831+
701832 >>> from datetime import datetime
702833 >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
703834 ... {"a": 1}, {"b": 2}, [1, 2, 3], None)])
@@ -715,7 +846,24 @@ def applySchema(self, rdd, schema):
715846 ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct.b, x.list, x.null))
716847 >>> srdd.collect()[0]
717848 (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
849+
850+ >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
851+ ... {"a": 1}, {"b": 2}, [1, 2, 3])])
852+ >>> schema = "byte short float time map{} struct(b) list[]"
853+ >>> srdd = sqlCtx.applySchema(rdd, schema)
854+ >>> srdd.collect()
855+ [Row(byte=127, short=-32768, float=1.0, time=..., struct=Row(b=2), list=[1, 2, 3])]
856+
718857 """
858+
859+ first = rdd .first ()
860+ if not isinstance (first , (tuple , list )):
861+ raise ValueError ("Can not apply schema to type: %s" % type (first ))
862+
863+ if isinstance (schema , basestring ):
864+ schema = _parse_schema_abstract (schema )
865+ schema = _infer_schema_type (first , schema )
866+
719867 batched = isinstance (rdd ._jrdd_deserializer , BatchedSerializer )
720868 jrdd = self ._pythonToJava (rdd ._jrdd , batched )
721869 srdd = self ._ssql_ctx .applySchemaToPythonRDD (jrdd .rdd (), str (schema ))
0 commit comments