Skip to content

Commit 9d8447c

Browse files
committed
apply a schema provided as a string of field names;
the types of the fields will be inferred automatically
1 parent f5df97f commit 9d8447c

File tree

1 file changed

+158
-10
lines changed

1 file changed

+158
-10
lines changed

python/pyspark/sql.py

Lines changed: 158 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class DataType(object):
4242
"""Spark SQL DataType"""
4343

4444
def __str__(self):
45+
return repr(self)
46+
47+
def __repr__(self):
4548
return self.__class__.__name__
4649

4750
def __hash__(self):
@@ -210,7 +213,7 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
210213
self.valueType = valueType
211214
self.valueContainsNull = valueContainsNull
212215

213-
def __str__(self):
216+
def __repr__(self):
214217
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
215218
str(self.valueContainsNull).lower())
216219

@@ -240,7 +243,7 @@ def __init__(self, name, dataType, nullable):
240243
self.dataType = dataType
241244
self.nullable = nullable
242245

243-
def __str__(self):
246+
def __repr__(self):
244247
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
245248
str(self.nullable).lower())
246249

@@ -267,7 +270,7 @@ def __init__(self, fields):
267270
"""
268271
self.fields = fields
269272

270-
def __str__(self):
273+
def __repr__(self):
271274
return ("StructType(List(%s))" %
272275
",".join(str(field) for field in self.fields))
273276

@@ -401,7 +404,7 @@ def _parse_datatype_string(datatype_string):
401404
datetime.time: TimestampType,
402405
}
403406

404-
def _inferType(obj):
407+
def _infer_type(obj):
405408
"""Infer the DataType from obj"""
406409
if obj is None:
407410
raise ValueError("Can not infer type for None")
@@ -414,18 +417,18 @@ def _inferType(obj):
414417
if not obj:
415418
raise ValueError("Can not infer type for empty dict")
416419
key, value = obj.iteritems().next()
417-
return MapType(_inferType(key), _inferType(value), True)
420+
return MapType(_infer_type(key), _infer_type(value), True)
418421
elif isinstance(obj, (list, array.array)):
419422
if not obj:
420423
raise ValueError("Can not infer type for empty list/array")
421-
return ArrayType(_inferType(obj[0]), True)
424+
return ArrayType(_infer_type(obj[0]), True)
422425
else:
423426
try:
424-
return _inferSchema(obj)
427+
return _infer_schema(obj)
425428
except ValueError:
426429
raise ValueError("not supported type: %s" % type(obj))
427430

428-
def _inferSchema(row):
431+
def _infer_schema(row):
429432
"""Infer the schema from dict/namedtuple/object"""
430433
if isinstance(row, dict):
431434
items = sorted(row.items())
@@ -440,7 +443,7 @@ def _inferSchema(row):
440443
else:
441444
raise ValueError("Can not infer schema for type: %s" % type(row))
442445

443-
fields = [StructField(k, _inferType(v), True) for k, v in items]
446+
fields = [StructField(k, _infer_type(v), True) for k, v in items]
444447
return StructType(fields)
445448

446449
def _create_converter(obj, dataType):
@@ -494,6 +497,133 @@ def _dropSchema(rows, schema):
494497
yield converter(i)
495498

496499

500+
_BRAKETS = {'(':')', '[':']', '{':'}'}
501+
502+
def _split_schema_abstract(s):
503+
"""
504+
split the schema abstract into fields
505+
506+
>>> _split_schema_abstract("a b c")
507+
['a', 'b', 'c']
508+
>>> _split_schema_abstract("a(a b)")
509+
['a(a b)']
510+
>>> _split_schema_abstract("a b[] c{a b}")
511+
['a', 'b[]', 'c{a b}']
512+
>>> _split_schema_abstract(" ")
513+
[]
514+
"""
515+
516+
r = []
517+
w = ''
518+
brackets = []
519+
for c in s:
520+
if c == ' ' and not brackets:
521+
if w:
522+
r.append(w)
523+
w = ''
524+
else:
525+
w += c
526+
if c in _BRAKETS:
527+
brackets.append(c)
528+
elif c in _BRAKETS.values():
529+
if not brackets or c != _BRAKETS[brackets.pop()]:
530+
raise ValueError("unexpected " + c)
531+
532+
if brackets:
533+
raise ValueError("brackets not closed: %s" % brackets)
534+
if w:
535+
r.append(w)
536+
return r
537+
538+
def _parse_field_abstract(s):
539+
"""
540+
Parse a field in schema abstract
541+
542+
>>> _parse_field_abstract("a")
543+
StructField(a,None,true)
544+
>>> _parse_field_abstract("b(c d)")
545+
StructField(b,StructType(List(StructField(c,None,true),StructField(d,None,true))),true)
546+
>>> _parse_field_abstract("a[]")
547+
StructField(a,ArrayType(None,true),true)
548+
>>> _parse_field_abstract("a{[]}")
549+
StructField(a,MapType(None,ArrayType(None,true),true),true)
550+
"""
551+
if set(_BRAKETS.keys()) & set(s):
552+
idx = min((s.index(c) for c in _BRAKETS if c in s))
553+
name = s[:idx]
554+
return StructField(name, _parse_schema_abstract(s[idx:]), True)
555+
else:
556+
return StructField(s, None, True)
557+
558+
def _parse_schema_abstract(s):
559+
"""
560+
parse abstract into schema
561+
562+
>>> _parse_schema_abstract("a b c")
563+
StructType...a...b...c...
564+
>>> _parse_schema_abstract("a[b c] b{}")
565+
StructType...a,ArrayType...b...c...b,MapType...
566+
>>> _parse_schema_abstract("c{} d{a b}")
567+
StructType...c,MapType...d,MapType...a...b...
568+
>>> _parse_schema_abstract("a b(t)").fields[1]
569+
StructField(b,StructType(List(StructField(t,None,true))),true)
570+
"""
571+
s = s.strip()
572+
if not s:
573+
return
574+
575+
elif s.startswith('('):
576+
return _parse_schema_abstract(s[1:-1])
577+
578+
elif s.startswith('['):
579+
return ArrayType(_parse_schema_abstract(s[1:-1]), True)
580+
581+
elif s.startswith('{'):
582+
return MapType(None, _parse_schema_abstract(s[1:-1]))
583+
584+
parts = _split_schema_abstract(s)
585+
fields = [_parse_field_abstract(p) for p in parts]
586+
return StructType(fields)
587+
588+
def _infer_schema_type(obj, dataType):
589+
"""
590+
Fill the dataType with types infered from obj
591+
592+
>>> schema = _parse_schema_abstract("a b c")
593+
>>> row = (1, 1.0, "str")
594+
>>> _infer_schema_type(row, schema)
595+
StructType...IntegerType...DoubleType...StringType...
596+
>>> row = [[1], {"key": (1, 2.0)}]
597+
>>> schema = _parse_schema_abstract("a[] b{c d}")
598+
>>> _infer_schema_type(row, schema)
599+
StructType...a,ArrayType...b,MapType(StringType,StructType...c,IntegerType...
600+
"""
601+
if dataType is None:
602+
return _infer_type(obj)
603+
604+
if not obj:
605+
raise ValueError("Can not infer type from empty value")
606+
607+
if isinstance(dataType, ArrayType):
608+
eType = _infer_schema_type(obj[0], dataType.elementType)
609+
return ArrayType(eType, True)
610+
611+
elif isinstance(dataType, MapType):
612+
k, v = obj.iteritems().next()
613+
return MapType(_infer_type(k),
614+
_infer_schema_type(v, dataType.valueType))
615+
616+
elif isinstance(dataType, StructType):
617+
fs = dataType.fields
618+
assert len(fs) == len(obj), "Obj(%s) have different length with fields(%s)" % (obj, fs)
619+
fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
620+
for o, f in zip(obj, fs)]
621+
return StructType(fields)
622+
623+
else:
624+
raise ValueError("Unexpected dataType: %s" % dataType)
625+
626+
497627
_cached_cls = {}
498628

499629
def _restore_object(fields, obj):
@@ -684,7 +814,7 @@ def inferSchema(self, rdd):
684814
if not first:
685815
raise ValueError("The first row in RDD is empty, can not infer schema")
686816

687-
schema = _inferSchema(first)
817+
schema = _infer_schema(first)
688818
rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema))
689819
return self.applySchema(rdd, schema)
690820

@@ -698,6 +828,7 @@ def applySchema(self, rdd, schema):
698828
>>> srdd2 = sqlCtx.sql("SELECT * from table1")
699829
>>> srdd2.collect()
700830
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2'), Row(field1=3, field2=u'row3')]
831+
701832
>>> from datetime import datetime
702833
>>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
703834
... {"a": 1}, {"b": 2}, [1, 2, 3], None)])
@@ -715,7 +846,24 @@ def applySchema(self, rdd, schema):
715846
... x.byte, x.short, x.float, x.time, x.map["a"], x.struct.b, x.list, x.null))
716847
>>> srdd.collect()[0]
717848
(127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
849+
850+
>>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
851+
... {"a": 1}, {"b": 2}, [1, 2, 3])])
852+
>>> schema = "byte short float time map{} struct(b) list[]"
853+
>>> srdd = sqlCtx.applySchema(rdd, schema)
854+
>>> srdd.collect()
855+
[Row(byte=127, short=-32768, float=1.0, time=..., struct=Row(b=2), list=[1, 2, 3])]
856+
718857
"""
858+
859+
first = rdd.first()
860+
if not isinstance(first, (tuple, list)):
861+
raise ValueError("Can not apply schema to type: %s" % type(first))
862+
863+
if isinstance(schema, basestring):
864+
schema = _parse_schema_abstract(schema)
865+
schema = _infer_schema_type(first, schema)
866+
719867
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
720868
jrdd = self._pythonToJava(rdd._jrdd, batched)
721869
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))

0 commit comments

Comments
 (0)