
Commit b65606f

Add converter interface
1 parent 5757f6e commit b65606f

File tree: 4 files changed (+118 -35 lines)

core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
python/pyspark/tests.py

core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala

Lines changed: 86 additions & 28 deletions
@@ -18,11 +18,84 @@
 package org.apache.spark.api.python
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io._
+import scala.util.{Failure, Success, Try}
 
 
+trait Converter {
+  def convert(obj: Any): Any
+}
+
+object DefaultConverter extends Converter {
+
+  /**
+   * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
+   * object representation
+   */
+  private def convertWritable(writable: Writable): Any = {
+    import collection.JavaConversions._
+    writable match {
+      case iw: IntWritable => SparkContext.intWritableConverter().convert(iw)
+      case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw)
+      case lw: LongWritable => SparkContext.longWritableConverter().convert(lw)
+      case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw)
+      case t: Text => SparkContext.stringWritableConverter().convert(t)
+      case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw)
+      case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw)
+      case n: NullWritable => null
+      case aw: ArrayWritable => aw.get().map(convertWritable(_))
+      case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) =>
+        (convertWritable(k), convertWritable(v))
+      }.toMap)
+      case other => other
+    }
+  }
+
+  def convert(obj: Any): Any = {
+    obj match {
+      case writable: Writable =>
+        convertWritable(writable)
+      case _ =>
+        obj
+    }
+  }
+}
+
+class ConverterRegistry extends Logging {
+
+  var keyConverter: Converter = DefaultConverter
+  var valueConverter: Converter = DefaultConverter
+
+  def convertKey(obj: Any): Any = keyConverter.convert(obj)
+
+  def convertValue(obj: Any): Any = valueConverter.convert(obj)
+
+  def registerKeyConverter(converterClass: String) = {
+    keyConverter = register(converterClass)
+    logInfo(s"Loaded and registered key converter ($converterClass)")
+  }
+
+  def registerValueConverter(converterClass: String) = {
+    valueConverter = register(converterClass)
+    logInfo(s"Loaded and registered value converter ($converterClass)")
+  }
+
+  private def register(converterClass: String): Converter = {
+    Try {
+      val converter = Class.forName(converterClass).newInstance().asInstanceOf[Converter]
+      converter
+    } match {
+      case Success(s) => s
+      case Failure(err) =>
+        logError(s"Failed to register converter: $converterClass")
+        throw err
+    }
+
+  }
+}
+
 /** Utilities for working with Python objects -> Hadoop-related objects */
 private[python] object PythonHadoopUtil {
 
@@ -51,33 +124,18 @@ private[python] object PythonHadoopUtil {
    * Converts an RDD of key-value pairs, where key and/or value could be instances of
    * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)]
    */
-  def convertRDD[K, V](rdd: RDD[(K, V)]) = {
-    rdd.map{
-      case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V])
-      case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V])
-      case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V])
-      case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V])
-    }
-  }
-
-  /**
-   * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
-   * object representation
-   */
-  private def convert(writable: Writable): Any = {
-    import collection.JavaConversions._
-    writable match {
-      case iw: IntWritable => SparkContext.intWritableConverter().convert(iw)
-      case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw)
-      case lw: LongWritable => SparkContext.longWritableConverter().convert(lw)
-      case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw)
-      case t: Text => SparkContext.stringWritableConverter().convert(t)
-      case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw)
-      case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw)
-      case n: NullWritable => null
-      case aw: ArrayWritable => aw.get().map(convert(_))
-      case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap)
-      case other => other
+  def convertRDD[K, V](rdd: RDD[(K, V)],
+                       keyClass: String,
+                       keyConverter: Option[String],
+                       valueClass: String,
+                       valueConverter: Option[String]) = {
+    rdd.mapPartitions { case iter =>
+      val registry = new ConverterRegistry
+      keyConverter.foreach(registry.registerKeyConverter(_))
+      valueConverter.foreach(registry.registerValueConverter(_))
+      iter.map { case (k, v) =>
+        (registry.convertKey(k).asInstanceOf[K], registry.convertValue(v).asInstanceOf[V])
+      }
     }
   }
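
A user-supplied converter only has to implement the Converter trait introduced above and be loadable by class name on the executors. As a rough, hypothetical sketch (the class name and behaviour are illustrative and not part of this commit), a converter that turns Text records into upper-case strings might look like:

class UpperCaseTextConverter extends Converter {
  // Anything that is not a Text is passed through unchanged.
  override def convert(obj: Any): Any = obj match {
    case t: org.apache.hadoop.io.Text => t.toString.toUpperCase
    case other => other
  }
}

It would then be supplied to sequenceFile (or the Hadoop RDD methods) as a fully qualified class name, and ConverterRegistry would instantiate it per partition via Class.forName, as the diff above shows.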

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 14 additions & 7 deletions
@@ -353,17 +353,20 @@ private[spark] object PythonRDD extends Logging {
   def sequenceFile[K, V](
       sc: JavaSparkContext,
       path: String,
-      keyClass: String,
-      valueClass: String,
+      keyClassMaybeNull: String,
+      valueClassMaybeNull: String,
       keyConverter: String,
       valueConverter: String,
       minSplits: Int) = {
+    val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
+    val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
     implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
     implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
     val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
     val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
     val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
-    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
+    val converted = PythonHadoopUtil.convertRDD[K, V](
+      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
     JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
   }
 
@@ -386,7 +389,8 @@ private[spark] object PythonRDD extends Logging {
     val rdd =
       newAPIHadoopRDDFromClassNames[K, V, F](sc,
        Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
-    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
+    val converted = PythonHadoopUtil.convertRDD[K, V](
+      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
     JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
   }
 
@@ -407,7 +411,8 @@ private[spark] object PythonRDD extends Logging {
     val rdd =
      newAPIHadoopRDDFromClassNames[K, V, F](sc,
        None, inputFormatClass, keyClass, valueClass, conf)
-    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
+    val converted = PythonHadoopUtil.convertRDD[K, V](
+      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
     JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
   }
 
@@ -451,7 +456,8 @@ private[spark] object PythonRDD extends Logging {
     val rdd =
      hadoopRDDFromClassNames[K, V, F](sc,
        Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
-    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
+    val converted = PythonHadoopUtil.convertRDD[K, V](
+      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
     JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
   }
 
@@ -472,7 +478,8 @@ private[spark] object PythonRDD extends Logging {
     val rdd =
      hadoopRDDFromClassNames[K, V, F](sc,
        None, inputFormatClass, keyClass, valueClass, conf)
-    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
+    val converted = PythonHadoopUtil.convertRDD[K, V](
+      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
     JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
   }
 

core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala

Lines changed: 8 additions & 0 deletions
@@ -54,6 +54,14 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten
   }
 }
 
+class TestConverter extends Converter {
+  import collection.JavaConversions._
+  override def convert(obj: Any) = {
+    val m = obj.asInstanceOf[MapWritable]
+    seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq)
+  }
+}
+
 /**
  * This object contains method to generate SequenceFile test data and write it to a
  * given directory (probably a temp directory)
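
To make the conversion concrete, here is a hypothetical standalone snippet (not part of the commit) showing what TestConverter produces: given a MapWritable keyed by DoubleWritables, it returns just the keys as a Java list of doubles.

import org.apache.hadoop.io.{DoubleWritable, MapWritable, Text}

val m = new MapWritable()
m.put(new DoubleWritable(2.0), new Text("ignored"))  // values are dropped by TestConverter
val keys = new TestConverter().convert(m)            // a java.util.List containing 2.0

This matches the expected output in the Python test below, where each MapWritable value becomes a list of its double keys.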

python/pyspark/tests.py

Lines changed: 10 additions & 0 deletions
@@ -335,6 +335,16 @@ def test_bad_inputs(self):
             "org.apache.hadoop.io.IntWritable",
             "org.apache.hadoop.io.Text"))
 
+    def test_converter(self):
+        basepath = self.tempdir.name
+        maps = sorted(self.sc.sequenceFile(
+            basepath + "/sftestdata/sfmap/",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable",
+            valueConverter="org.apache.spark.api.python.TestConverter").collect())
+        em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])]
+        self.assertEqual(maps, em)
+
 
 class TestDaemon(unittest.TestCase):
     def connect(self, port):