
Commit 76dceb2

Use TextFileFormat in implementation of JsonFileFormat
1 parent fb9beda commit 76dceb2

File tree

5 files changed (+122 lines, -94 lines)


sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 3 additions & 6 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
+import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String

@@ -360,17 +360,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions.toMap,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    val createParser = CreateJacksonParser.string _

     val schema = userSpecifiedSchema.getOrElse {
-      JsonInferSchema.infer(
-        jsonDataset.rdd,
-        parsedOptions,
-        createParser)
+      TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }

     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)

+    val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
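
Note: with this hunk, schema inference for the Dataset[String] overload of json() is delegated to TextInputJsonDataSource.inferFromDataset, so it shares the sampling and inference path used for file-based input; only the per-partition JacksonParser invocation stays in DataFrameReader. A minimal caller-side sketch of the path this touches (the SparkSession setup and sample records below are illustrative, not part of the commit):

    import org.apache.spark.sql.{Dataset, SparkSession}

    val spark = SparkSession.builder().master("local[*]").appName("json-infer").getOrCreate()
    import spark.implicits._

    // In-memory JSON lines; after this change, reading them infers the schema
    // via TextInputJsonDataSource.inferFromDataset.
    val jsonDataset: Dataset[String] = Seq(
      """{"id": 1, "name": "a"}""",
      """{"id": 2, "name": "b"}"""
    ).toDS()

    val df = spark.read.json(jsonDataset)
    df.printSchema()  // id inferred as long, name as string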

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 66 additions & 79 deletions
@@ -17,32 +17,30 @@
 
 package org.apache.spark.sql.execution.datasources.json
 
-import scala.reflect.ClassTag
-
 import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
-import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.io.Text
 import org.apache.hadoop.mapreduce.Job
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
  * Common functions for parsing JSON files
- * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
  */
-abstract class JsonDataSource[T] extends Serializable {
+abstract class JsonDataSource extends Serializable {
   def isSplitable: Boolean
 
   /**

@@ -53,35 +51,24 @@ abstract class JsonDataSource[T] extends Serializable {
       file: PartitionedFile,
       parser: JacksonParser): Iterator[InternalRow]
 
-  /**
-   * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
-   */
-  protected def createBaseRdd(
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[T]
-
-  /**
-   * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
-   * for an instance of [[T]]
-   */
-  def createParser(jsonFactory: JsonFactory, value: T): JsonParser
-
-  final def infer(
+  final def inferSchema(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Option[StructType] = {
     if (inputPaths.nonEmpty) {
-      val jsonSchema = JsonInferSchema.infer(
-        createBaseRdd(sparkSession, inputPaths),
-        parsedOptions,
-        createParser)
+      val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
       checkConstraints(jsonSchema)
       Some(jsonSchema)
     } else {
       None
     }
   }
 
+  protected def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType
+
   /** Constraints to be imposed on schema to be stored. */
   private def checkConstraints(schema: StructType): Unit = {
     if (schema.fieldNames.length != schema.fieldNames.distinct.length) {

@@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable {
 }
 
 object JsonDataSource {
-  def apply(options: JSONOptions): JsonDataSource[_] = {
+  def apply(options: JSONOptions): JsonDataSource = {
     if (options.wholeFile) {
       WholeFileJsonDataSource
     } else {
       TextInputJsonDataSource
     }
   }
-
-  /**
-   * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
-   * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
-   */
-  def createBaseRdd[T : ClassTag](
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus])(
-      fn: (Configuration, String) => RDD[T]): RDD[T] = {
-    val paths = inputPaths.map(_.getPath)
-
-    if (paths.nonEmpty) {
-      val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
-      FileInputFormat.setInputPaths(job, paths: _*)
-      fn(job.getConfiguration, paths.mkString(","))
-    } else {
-      sparkSession.sparkContext.emptyRDD[T]
-    }
-  }
 }
 
-object TextInputJsonDataSource extends JsonDataSource[Text] {
+object TextInputJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     // splittable if the underlying source is
     true
   }
 
-  override protected def createBaseRdd(
+  override def infer(
       sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[Text] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        sparkSession.sparkContext.newAPIHadoopRDD(
-          conf,
-          classOf[TextInputFormat],
-          classOf[LongWritable],
-          classOf[Text])
-          .setName(s"JsonLines: $name")
-          .values // get the text column
-    }
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+    inferFromDataset(json, parsedOptions)
+  }
+
+  def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
+    val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
+    val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
+    JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+  }
+
+  private def createBaseDataset(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus]): Dataset[String] = {
+    val paths = inputPaths.map(_.getPath.toString)
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        className = classOf[TextFileFormat].getName
+      ).resolveRelation(checkFilesExist = false))
+      .select("value").as(Encoders.STRING)
   }
 
   override def readFile(

@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
       parser: JacksonParser): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
-    linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+    linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
   }
 
   private def textToUTF8String(value: Text): UTF8String = {
     UTF8String.fromBytes(value.getBytes, 0, value.getLength)
   }
-
-  override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
-    CreateJacksonParser.text(jsonFactory, value)
-  }
 }
 
-object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+object WholeFileJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     false
   }
 
-  override protected def createBaseRdd(
+  override def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+    val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
+    JsonInferSchema.infer(sampled, parsedOptions, createParser)
+  }
+
+  private def createBaseRdd(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        new BinaryFileRDD(
-          sparkSession.sparkContext,
-          classOf[StreamInputFormat],
-          classOf[String],
-          classOf[PortableDataStream],
-          conf,
-          sparkSession.sparkContext.defaultMinPartitions)
-          .setName(s"JsonFile: $name")
-          .values
-    }
+    val paths = inputPaths.map(_.getPath)
+    val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+    val conf = job.getConfiguration
+    val name = paths.mkString(",")
+    FileInputFormat.setInputPaths(job, paths: _*)
+    new BinaryFileRDD(
+      sparkSession.sparkContext,
+      classOf[StreamInputFormat],
+      classOf[String],
+      classOf[PortableDataStream],
+      conf,
+      sparkSession.sparkContext.defaultMinPartitions)
+      .setName(s"JsonFile: $name")
+      .values
  }
 
-  override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+  private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
     CreateJacksonParser.inputStream(
       jsonFactory,
       CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
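
Note: JsonDataSource drops its type parameter, each implementation now owns schema inference, and the line-oriented source builds a Dataset[String] on top of TextFileFormat (via DataSource.resolveRelation) instead of a raw Hadoop TextInputFormat RDD. Expressed with the public API, createBaseDataset amounts to roughly the following (a sketch only; the path is illustrative, and the internal version additionally skips the existence check with checkFilesExist = false):

    import org.apache.spark.sql.{Dataset, SparkSession}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // spark.read.text is backed by TextFileFormat; the single "value" column
    // holds one line of raw JSON per row.
    val lines: Dataset[String] =
      spark.read.text("/tmp/input.json").select("value").as[String]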

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    JsonDataSource(parsedOptions).infer(
+    JsonDataSource(parsedOptions).inferSchema(
       sparkSession, files, parsedOptions)
   }
 
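
Note: only the entry point is renamed here; JsonDataSource(parsedOptions).inferSchema still returns Option[StructType] (None when there are no input files) and still dispatches on the wholeFile option. As of this commit the selection looks roughly like this from the caller's side (paths are illustrative and `spark` is an existing SparkSession):

    // Per-line JSON: handled by TextInputJsonDataSource (splittable).
    val events = spark.read.json("/data/events.json")

    // One JSON document per file: handled by WholeFileJsonDataSource (not splittable).
    val docs = spark.read.option("wholeFile", "true").json("/data/docs")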

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala

Lines changed: 1 addition & 8 deletions
@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
       json: RDD[T],
       configOptions: JSONOptions,
       createParser: (JsonFactory, T) => JsonParser): StructType = {
-    require(configOptions.samplingRatio > 0,
-      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
     val shouldHandleCorruptRecord = configOptions.permissive
     val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
-    val schemaData = if (configOptions.samplingRatio > 0.99) {
-      json
-    } else {
-      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
-    }
 
     // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
+    val rootType = json.mapPartitions { iter =>
       val factory = new JsonFactory()
       configOptions.setJacksonOptions(factory)
       iter.flatMap { row =>
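
Note: the require() check and the samplingRatio handling move out of JsonInferSchema.infer into the new JsonUtils.sample helpers, so infer now assumes its input has already been sampled by the caller. What remains is the per-partition pass sketched below (a generic illustration, not the Spark method itself): one JsonFactory per partition, one parser per record.

    import com.fasterxml.jackson.core.JsonFactory
    import org.apache.spark.rdd.RDD

    // Sketch: report the first Jackson token of each record, creating the
    // factory once per partition as JsonInferSchema.infer does.
    def firstTokens(json: RDD[String]): RDD[String] =
      json.mapPartitions { iter =>
        val factory = new JsonFactory()
        iter.map { row =>
          val parser = factory.createParser(row)
          try parser.nextToken().name()   // e.g. START_OBJECT
          finally parser.close()
        }
      }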

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.json.JSONOptions
+
+object JsonUtils {
+  /**
+   * Sample JSON dataset as configured by `samplingRatio`.
+   */
+  def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+
+  /**
+   * Sample JSON RDD as configured by `samplingRatio`.
+   */
+  def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+}
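
Note: both overloads enforce samplingRatio > 0, pass the data through untouched when the ratio is above 0.99, and otherwise sample without replacement with a fixed seed of 1, so inference is deterministic for a given input. A usage sketch, assuming the three-argument JSONOptions constructor seen in the DataFrameReader hunk above; the data and time zone are illustrative and `spark` is an existing SparkSession:

    import org.apache.spark.sql.catalyst.json.JSONOptions

    import spark.implicits._
    val jsonLines = Seq("""{"a": 1}""", """{"b": 2}""", """{"c": 3}""").toDS()

    // samplingRatio defaults to 1.0 when unset, so the input comes back unchanged.
    val all = JsonUtils.sample(jsonLines, new JSONOptions(Map.empty, "UTC", "_corrupt_record"))

    // With samplingRatio = 0.5, roughly half the rows are kept (seed fixed to 1).
    val opts = new JSONOptions(Map("samplingRatio" -> "0.5"), "UTC", "_corrupt_record")
    val some = JsonUtils.sample(jsonLines, opts)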
