
Commit 19812a8

Fixed the serialization issue with PortableDataStream: since neither CombineFileSplit nor TaskAttemptContext implements the Serializable interface, both are now stored as byte arrays, and the objects are recreated from those byte arrays as needed.
1 parent 238c83c commit 19812a8
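
The approach is the usual workaround for holding non-serializable Hadoop Writable objects inside a Serializable class: write the object to a byte array up front, let the byte array travel through Java serialization, and lazily rebuild the object on first use. A minimal sketch of that pattern using a Hadoop Configuration follows; the ConfHolder name and structure are illustrative only, not part of this change.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import org.apache.hadoop.conf.Configuration

    // Holds a non-serializable Configuration by keeping only its Writable byte form.
    class ConfHolder(@transient inputConf: Configuration) extends Serializable {
      // Captured eagerly at construction time; a plain Array[Byte] serializes without trouble.
      private val confBytes: Array[Byte] = {
        val baos = new ByteArrayOutputStream()
        inputConf.write(new DataOutputStream(baos))
        baos.toByteArray
      }
      // Rebuilt lazily after deserialization; @transient keeps it out of the object stream.
      @transient private lazy val conf: Configuration = {
        val nconf = new Configuration()
        nconf.readFields(new DataInputStream(new ByteArrayInputStream(confBytes)))
        nconf
      }
      def get(key: String): String = conf.get(key)
    }

PortableDataStream below applies the same idea twice: splitBytes for the CombineFileSplit and confBytes for the TaskAttemptContext's Configuration.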

4 files changed: 78 additions & 7 deletions

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 2 additions & 1 deletion
@@ -289,7 +289,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
   def binaryArrays(path: String, minPartitions: Int = defaultMinPartitions):
-  JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions).mapValues(_.toArray()))
+  JavaPairRDD[String, Array[Byte]] =
+  new JavaPairRDD(sc.binaryFiles(path,minPartitions).mapValues(_.toArray()))
 
   /**
    * Load data from a flat binary file, assuming each record is a set of numbers
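
For context, binaryArrays is only a Java-friendly wrapper around the expression on its right-hand side. A hedged usage sketch of the underlying Scala call, with a made-up input directory:

    // Read each (small) binary file whole, as (path, contents) pairs.
    // "/tmp/blobs" is a hypothetical directory used only for illustration.
    val bytesRdd = sc.binaryFiles("/tmp/blobs").mapValues(_.toArray())
    bytesRdd.collect().foreach { case (path, bytes) =>
      println(s"$path: ${bytes.length} bytes")
    }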

core/src/main/scala/org/apache/spark/input/RawFileInput.scala

Lines changed: 39 additions & 5 deletions
@@ -20,14 +20,15 @@ package org.apache.spark.input
 import scala.collection.JavaConversions._
 import com.google.common.io.{ByteStreams, Closeables}
 import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
 import org.apache.hadoop.mapreduce.RecordReader
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 import org.apache.hadoop.fs.{FSDataInputStream, Path}
 import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
 import org.apache.hadoop.mapreduce.JobContext
 import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
-import java.io.DataInputStream
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutputStream, DataInputStream}
 
 
@@ -58,17 +59,50 @@ abstract class StreamFileInputFormat[T]
 /**
  * A class that allows DataStreams to be serialized and moved around by not creating them
  * until they need to be read
+ * @note TaskAttemptContext is not serializable resulting in the confBytes construct
+ * @note CombineFileSplit is not serializable resulting in the splitBytes construct
  */
-class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, index: Integer)
+class PortableDataStream(@transient isplit: CombineFileSplit, @transient context: TaskAttemptContext, index: Integer)
   extends Serializable {
-  // transient forces file to be reopened after being moved (serialization)
+  // transient forces file to be reopened after being serialization
+  // it is also used for non-serializable classes
+
   @transient
   private var fileIn: FSDataInputStream = null.asInstanceOf[FSDataInputStream]
   @transient
   private var isOpen = false
+
+  private val confBytes = {
+    val baos = new ByteArrayOutputStream()
+    context.getConfiguration.write(new DataOutputStream(baos))
+    baos.toByteArray
+  }
+
+  private val splitBytes = {
+    val baos = new ByteArrayOutputStream()
+    isplit.write(new DataOutputStream(baos))
+    baos.toByteArray
+  }
+
+  @transient
+  private lazy val split = {
+    val bais = new ByteArrayInputStream(splitBytes)
+    val nsplit = new CombineFileSplit()
+    nsplit.readFields(new DataInputStream(bais))
+    nsplit
+  }
+
+  @transient
+  private lazy val conf = {
+    val bais = new ByteArrayInputStream(confBytes)
+    val nconf = new Configuration()
+    nconf.readFields(new DataInputStream(bais))
+    nconf
+  }
   /**
    * Calculate the path name independently of opening the file
    */
+  @transient
   private lazy val path = {
     val pathp = split.getPath(index)
     pathp.toString
@@ -80,7 +114,7 @@ class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, i
   def open(): FSDataInputStream = {
     if (!isOpen) {
       val pathp = split.getPath(index)
-      val fs = pathp.getFileSystem(context.getConfiguration)
+      val fs = pathp.getFileSystem(conf)
       fileIn = fs.open(pathp)
       isOpen=true
     }
@@ -207,4 +241,4 @@ abstract class BinaryRecordReader[T](
     parseByteArray(innerBuffer)
   }
   def parseByteArray(inArray: Array[Byte]): T
-}
+}
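
The splitBytes half of the change relies on CombineFileSplit being a Hadoop Writable, so it can be rebuilt byte-for-byte on whichever executor deserializes the PortableDataStream. A standalone sketch of that round trip; roundTripSplit and originalSplit are illustrative names, and the split is assumed to come from some CombineFileInputFormat.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit

    // Serialize a CombineFileSplit to bytes via its Writable interface, then rebuild it.
    def roundTripSplit(originalSplit: CombineFileSplit): CombineFileSplit = {
      val baos = new ByteArrayOutputStream()
      originalSplit.write(new DataOutputStream(baos))
      val rebuilt = new CombineFileSplit()
      rebuilt.readFields(new DataInputStream(new ByteArrayInputStream(baos.toByteArray)))
      rebuilt // getPath(index), getLength, etc. behave as on the original
    }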

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 0 additions & 1 deletion
@@ -881,7 +881,6 @@ public void binaryFilesCaching() throws Exception {
     readRDD.foreach(new VoidFunction<Tuple2<String,PortableDataStream>>() {
       @Override
       public void call(Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2) throws Exception {
-        stringPortableDataStreamTuple2._2().getPath();
         stringPortableDataStreamTuple2._2().toArray(); // force the file to read
       }
     });

core/src/test/scala/org/apache/spark/FileSuite.scala

Lines changed: 37 additions & 0 deletions
@@ -280,6 +280,43 @@ class FileSuite extends FunSuite with LocalSparkContext {
     assert(indata.toArray === testOutput)
   }
 
+  test("portabledatastream flatmap tests") {
+    sc = new SparkContext("local", "test")
+    val outFile = new File(tempDir, "record-bytestream-00000.bin")
+    val outFileName = outFile.getAbsolutePath()
+
+    // create file
+    val testOutput = Array[Byte](1,2,3,4,5,6)
+    val numOfCopies = 3
+    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    channel.write(bbuf)
+    channel.close()
+    file.close()
+
+    val inRdd = sc.binaryFiles(outFileName)
+    val mappedRdd = inRdd.map{
+      curData: (String, PortableDataStream) =>
+        (curData._2.getPath(),curData._2)
+    }
+    val copyRdd = mappedRdd.flatMap{
+      curData: (String, PortableDataStream) =>
+        for(i <- 1 to numOfCopies) yield (i,curData._2)
+    }
+
+    val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect()
+
+    // Try reading the output back as an object file
+    assert(copyArr.length == numOfCopies)
+    copyArr.foreach{
+      cEntry: (Int, PortableDataStream) =>
+        assert(cEntry._2.toArray === testOutput)
+    }
+
+  }
+
   test("fixed record length binary file as byte array") {
     // a fixed length of 6 bytes