
Commit 7c7caf8

support accessing SQLConf inside tasks
1 parent e39b7d0 commit 7c7caf8

File tree: 12 files changed (+184, -45 lines)

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 2 additions & 0 deletions
@@ -178,4 +178,6 @@ private[spark] class TaskContextImpl(
 
   private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
 
+  // TODO: shall we publish it and define it in `TaskContext`?
+  private[spark] def getLocalProperties(): Properties = localProperties
 }
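
As context for how this hook is meant to be used, here is a small illustrative sketch (not part of this commit) of task-side code reading a driver-set local property through the public `TaskContext` API; the property key is only an example.

```scala
import org.apache.spark.TaskContext

// Runs inside a task: TaskContext.get() is non-null only while a task is executing.
def localShufflePartitions: Option[Int] =
  Option(TaskContext.get())
    .flatMap(ctx => Option(ctx.getLocalProperty("spark.sql.shuffle.partitions")))
    .map(_.toInt)
```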

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala

Lines changed: 0 additions & 13 deletions
@@ -78,17 +78,4 @@ private[sql] object CreateJacksonParser extends Serializable {
   def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
     jsonFactory.createParser(new InputStreamReader(is, enc))
   }
-
-  def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
-    val ba = row.getBinary(0)
-
-    jsonFactory.createParser(ba, 0, ba.length)
-  }
-
-  def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
-    val binary = row.getBinary(0)
-    val sd = getStreamDecoder(enc, binary, binary.length)
-
-    jsonFactory.createParser(sd)
-  }
 }
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala (new file)

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util.{Map => JMap}
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}
+
+/**
+ * A readonly SQLConf that will be created by tasks running at the executor side. It reads the
+ * configs from the local properties which are propagated from driver to executors.
+ */
+class ReadOnlySQLConf(context: TaskContext) extends SQLConf {
+
+  @transient override val settings: JMap[String, String] = {
+    context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
+  }
+
+  @transient override protected val reader: ConfigReader = {
+    new ConfigReader(new TaskContextConfigProvider(context))
+  }
+
+  override protected def setConfWithCheck(key: String, value: String): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def unsetConf(key: String): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def unsetConf(entry: ConfigEntry[_]): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def clear(): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+}
+
+class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
+  override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
+}
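
To make the read-only contract concrete, here is a hedged sketch (not code from this commit) of what a task-side caller can expect once `SQLConf.get` hands back a `ReadOnlySQLConf`; the config keys used are only examples.

```scala
import scala.util.Try
import org.apache.spark.sql.internal.SQLConf

// Meant to run inside a task (e.g. from within mapPartitions), where SQLConf.get
// returns a ReadOnlySQLConf backed by the TaskContext's local properties.
def checkTaskSideConf(): Unit = {
  val conf = SQLConf.get
  println(s"caseSensitive = ${conf.getConf(SQLConf.CASE_SENSITIVE)}")          // reads work as usual
  assert(Try(conf.setConfString("spark.sql.cbo.enabled", "true")).isFailure)   // writes throw
}
```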

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 11 deletions
@@ -27,13 +27,12 @@ import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
-import org.apache.spark.util.Utils
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines the configuration options for Spark SQL.
@@ -107,7 +106,13 @@ object SQLConf {
    * run tests in parallel. At the time this feature was implemented, this was a no-op since we
    * run unit tests (that does not involve SparkSession) in serial order.
    */
-  def get: SQLConf = confGetter.get()()
+  def get: SQLConf = {
+    if (TaskContext.get != null) {
+      new ReadOnlySQLConf(TaskContext.get())
+    } else {
+      confGetter.get()()
+    }
+  }
 
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
@@ -1274,17 +1279,11 @@ object SQLConf {
 class SQLConf extends Serializable with Logging {
   import SQLConf._
 
-  if (Utils.isTesting && SparkEnv.get != null) {
-    // assert that we're only accessing it on the driver.
-    assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
-      "SQLConf should only be created and accessed on the driver.")
-  }
-
   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
   @transient protected[spark] val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())
 
-  @transient private val reader = new ConfigReader(settings)
+  @transient protected val reader = new ConfigReader(settings)
 
   /** ************************ Spark SQL Params/Hints ******************* */
 
@@ -1734,7 +1733,7 @@ class SQLConf extends Serializable with Logging {
     settings.containsKey(key)
   }
 
-  private def setConfWithCheck(key: String, value: String): Unit = {
+  protected def setConfWithCheck(key: String, value: String): Unit = {
     settings.put(key, value)
   }
 
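
With this dispatch in place, the same lookup works on the driver and inside a task. A minimal, hypothetical usage sketch follows; `df` and `process` are placeholders, and the config read is only an example.

```scala
import org.apache.spark.sql.internal.SQLConf

df.rdd.mapPartitions { iter =>
  // On the executor this resolves through ReadOnlySQLConf and the task's local properties;
  // on the driver the same call returns the active session's conf.
  val timeZone = SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
  iter.map(row => process(row, timeZone))
}
```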

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 3 additions & 1 deletion
@@ -1607,7 +1607,9 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def reduce(func: (T, T) => T): T = rdd.reduce(func)
+  def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
+    rdd.reduce(func)
+  }
 
   /**
    * :: Experimental ::
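
Wrapping `reduce` in `withNewRDDExecutionId` means the RDD job it triggers runs under a SQL execution, so the session's SQL configs are set as local properties before the job launches. A minimal usage sketch, assuming an existing SparkSession named `spark`:

```scala
import spark.implicits._   // `spark` is an existing SparkSession (assumption)

val ds = Seq(1L, 2L, 3L).toDS()
// The job started here now sees the propagated SQL configs via TaskContext local properties.
val sum = ds.reduce(_ + _)   // 6
```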

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 26 additions & 4 deletions
@@ -68,16 +68,27 @@ object SQLExecution {
       // sparkContext.getCallSite() would first try to pick up any call site that was previously
       // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
       // streaming queries would give us call site like "run at <unknown>:0"
-      val callSite = sparkSession.sparkContext.getCallSite()
+      val callSite = sc.getCallSite()
 
-      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+      // Set all the specified SQL configs to local properties, so that they can be available at
+      // the executor side.
+      val allConfigs = sparkSession.sessionState.conf.getAllConfs
+      allConfigs.foreach {
+        // Excludes external configs defined by users.
+        case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value)
+      }
+
+      sc.listenerBus.post(SparkListenerSQLExecutionStart(
         executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
         SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
       try {
         body
       } finally {
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+        sc.listenerBus.post(SparkListenerSQLExecutionEnd(
           executionId, System.currentTimeMillis()))
+        allConfigs.foreach {
+          case (key, _) => sc.setLocalProperty(key, null)
+        }
       }
     } finally {
       executionIdToQueryExecution.remove(executionId)
@@ -90,12 +101,23 @@ object SQLExecution {
   * thread from the original one, this method can be used to connect the Spark jobs in this action
   * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
   */
-  def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+  def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
+    val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    // Set all the specified SQL configs to local properties, so that they can be available at
+    // the executor side.
+    val allConfigs = sparkSession.sessionState.conf.getAllConfs
+    allConfigs.foreach {
+      // Excludes external configs defined by users.
+      case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value)
+    }
     try {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
       body
     } finally {
+      allConfigs.foreach {
+        case (key, _) => sc.setLocalProperty(key, null)
+      }
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
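
The pattern in both hunks is the same: copy the session's SQL configs into SparkContext local properties before running the body and clear them afterwards, since local properties are shipped to every task automatically. A standalone, hedged sketch of that pattern using only public APIs (the commit itself reads `sessionState.conf.getAllConfs`; `spark.conf.getAll` is used here as its public stand-in):

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: set the session's SQL configs as local properties around `body`,
// mirroring what SQLExecution does in this commit.
def withPropagatedSQLConfs[T](spark: SparkSession)(body: => T): T = {
  val sc = spark.sparkContext
  val confs = spark.conf.getAll.filter { case (key, _) => key.startsWith("spark") }
  confs.foreach { case (key, value) => sc.setLocalProperty(key, value) }
  try body finally confs.keys.foreach(key => sc.setLocalProperty(key, null))
}
```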

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert data to Scala types
         val rows: Array[InternalRow] = child.executeCollect()

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 3 additions & 7 deletions
@@ -99,12 +99,7 @@ object TextInputJsonDataSource extends JsonDataSource {
 
   def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
     val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
-    val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd
-    val rowParser = parsedOptions.encoding.map { enc =>
-      CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
-    }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
-
-    JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    JsonInferSchema.infer(sampled, parsedOptions, CreateJacksonParser.string)
   }
 
   private def createBaseDataset(
@@ -165,7 +160,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
       .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
       .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
 
-    JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    JsonInferSchema.infer[PortableDataStream](
+      sparkSession.createDataset(sampled)(Encoders.javaSerialization), parsedOptions, parser)
   }
 
   private def createBaseRdd(
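
Since `JsonInferSchema.infer` now takes a Dataset instead of an RDD, the multi-line path wraps its `RDD[PortableDataStream]` into a Dataset with a Java-serialization encoder. A hedged, generic sketch of that wrapping step (the helper name and parameters are illustrative):

```scala
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

// Wrap an RDD of an arbitrary serializable type into a Dataset, as done above for
// PortableDataStream, using the Java-serialization encoder.
def toDataset[T <: java.io.Serializable : ClassTag](spark: SparkSession, rdd: RDD[T]): Dataset[T] =
  spark.createDataset(rdd)(Encoders.javaSerialization[T])
```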

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala

Lines changed: 12 additions & 5 deletions
@@ -22,7 +22,7 @@ import java.util.Comparator
 import com.fasterxml.jackson.core._
 
 import org.apache.spark.SparkException
-import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, Encoders}
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
 import org.apache.spark.sql.catalyst.json.JSONOptions
@@ -39,14 +39,14 @@ private[sql] object JsonInferSchema {
    * 3. Replace any remaining null fields with string, the top type
    */
   def infer[T](
-      json: RDD[T],
+      json: Dataset[T],
       configOptions: JSONOptions,
       createParser: (JsonFactory, T) => JsonParser): StructType = {
     val parseMode = configOptions.parseMode
     val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
 
     // perform schema inference on each row and merge afterwards
-    val rootType = json.mapPartitions { iter =>
+    val inferredTypes = json.mapPartitions { iter =>
       val factory = new JsonFactory()
       configOptions.setJacksonOptions(factory)
       iter.flatMap { row =>
@@ -67,8 +67,15 @@ private[sql] object JsonInferSchema {
           }
         }
       }
-    }.fold(StructType(Nil))(
-      compatibleRootType(columnNameOfCorruptRecord, parseMode))
+    }(Encoders.javaSerialization)
+
+    // TODO: use `Dataset.fold` once we have it.
+    val rootType = try {
+      inferredTypes.reduce(compatibleRootType(columnNameOfCorruptRecord, parseMode))
+    } catch {
+      case e: UnsupportedOperationException if e.getMessage == "empty collection" =>
+        StructType(Nil)
+    }
 
     canonicalizeType(rootType) match {
       case Some(st: StructType) => st
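
Because Dataset has no `fold` (see the TODO above), the commit emulates it with `reduce` plus a fallback for the empty case, relying on `Dataset.reduce` throwing `UnsupportedOperationException("empty collection")` when there are no rows. A generic, hedged sketch of that workaround:

```scala
import org.apache.spark.sql.Dataset

// Fold-like helper on top of Dataset.reduce: returns `zero` for an empty Dataset
// instead of propagating the "empty collection" error.
def foldDataset[T](ds: Dataset[T], zero: T)(merge: (T, T) => T): T =
  try ds.reduce(merge)
  catch {
    case e: UnsupportedOperationException if e.getMessage == "empty collection" => zero
  }
```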

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ case class BroadcastExchangeExec(
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
         try {
           val beforeCollect = System.nanoTime()
           // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
