Commit 197e806

Support multiple sessions for ThriftServer
1 parent 7683982 commit 197e806


19 files changed: +254 -72 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala

Lines changed: 6 additions & 1 deletion
@@ -77,7 +77,8 @@ trait Catalog {
 }

 class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
-  val tables = new mutable.HashMap[String, LogicalPlan]()
+  import scala.collection.mutable.SynchronizedMap
+  val tables = new mutable.HashMap[String, LogicalPlan]() with SynchronizedMap[String, LogicalPlan]

   override def registerTable(
       tableIdentifier: Seq[String],
@@ -134,9 +135,11 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
  * lost when the JVM exits.
  */
 trait OverrideCatalog extends Catalog {
+  import scala.collection.mutable.SynchronizedMap

   // TODO: This doesn't work when the database changes...
   val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()
+    with SynchronizedMap[(Option[String],String), LogicalPlan]

   abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
     val tableIdent = processTableIdentifier(tableIdentifier)
@@ -235,3 +238,5 @@ object EmptyCatalog extends Catalog {
     throw new UnsupportedOperationException
   }
 }
+
+object SimpleCaseSensitiveCatalog extends SimpleCatalog(true)
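
A note on the pattern: mixing SynchronizedMap into a mutable HashMap synchronizes each individual map operation on the map's monitor, which is what lets the now-shared SimpleCatalog be hit by concurrent ThriftServer sessions. Below is a minimal, self-contained sketch of the idiom — the demo names are ours, and SynchronizedMap is the Scala-2.10-era tool (later deprecated in favour of java.util.concurrent.ConcurrentHashMap):

import scala.collection.mutable
import scala.collection.mutable.SynchronizedMap

object SynchronizedCatalogSketch extends App {
  // Each get/put/remove is synchronized on the map instance, so concurrent
  // registrations cannot corrupt the underlying HashMap. Compound
  // check-then-act sequences would still need an external lock.
  val tables = new mutable.HashMap[String, String]()
    with SynchronizedMap[String, String]

  val writers = (1 to 4).map { i =>
    new Thread(new Runnable {
      def run(): Unit =
        (1 to 1000).foreach(j => tables += (s"table_${i}_$j" -> "plan"))
    })
  }
  writers.foreach(_.start())
  writers.foreach(_.join())
  assert(tables.size == 4000) // no lost updates
}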

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
@@ -93,3 +93,5 @@ class StringKeyHashMap[T](normalizer: (String) => String) {
   def iterator: Iterator[(String, T)] = base.toIterator
 }

+object SimpleCaseSentiveFunctionRegistry extends SimpleFunctionRegistry(true)
+object SimpleInCaseSentiveFunctionRegistry extends SimpleFunctionRegistry(false)
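
The two singletons differ only in the case-sensitivity flag they pass to SimpleFunctionRegistry, which — as the StringKeyHashMap context above suggests — runs every key through a normalizer before access. A hedged sketch of that normalizer pattern, with illustrative names of our own:

import scala.collection.mutable

// Sketch of a key-normalizing map: the case-insensitive variant
// lower-cases keys on both insert and lookup.
class StringKeyMapSketch[T](normalizer: String => String) {
  private val base = new mutable.HashMap[String, T]()
  def update(key: String, value: T): Unit = base(normalizer(key)) = value
  def get(key: String): Option[T] = base.get(normalizer(key))
}

object RegistrySketch extends App {
  val caseInsensitive = new StringKeyMapSketch[Int](_.toLowerCase)
  caseInsensitive("SUM") = 1
  assert(caseInsensitive.get("sum") == Some(1))

  val caseSensitive = new StringKeyMapSketch[Int](identity)
  caseSensitive("SUM") = 1
  assert(caseSensitive.get("sum") == None)
}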

sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala

Lines changed: 16 additions & 9 deletions
@@ -34,24 +34,30 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR
  * InMemoryRelation. This relation is automatically substituted query plans that return the
  * `sameResult` as the originally cached query.
  *
+ * TODO Cached Data (Global wide V.S. Catalog Instance wide)
  * Internal to Spark SQL.
  */
-private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
-
+private[sql] object CacheManager extends Logging {
   @transient
-  private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+  private[this] val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]

   @transient
-  private val cacheLock = new ReentrantReadWriteLock
+  private[this] val cacheLock = new ReentrantReadWriteLock

   /** Returns true if the table is currently cached in-memory. */
-  def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty
+  private[sql] def isCached(sqlContext: SQLContext, tableName: String): Boolean = {
+    lookupCachedData(sqlContext.table(tableName)).nonEmpty
+  }

   /** Caches the specified table in-memory. */
-  def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName))
+  private[sql] def cacheTable(sqlContext: SQLContext, tableName: String): Unit = {
+    cacheQuery(sqlContext.table(tableName), Some(tableName))
+  }

   /** Removes the specified table from the in-memory cache. */
-  def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName))
+  private[sql] def uncacheTable(sqlContext: SQLContext, tableName: String): Unit = {
+    uncacheQuery(sqlContext.table(tableName))
+  }

   /** Acquires a read lock on the cache for the duration of `f`. */
   private def readLock[A](f: => A): A = {
@@ -99,8 +105,8 @@ private[sql] object CacheManager extends Logging {
       CachedData(
         planToCache,
         InMemoryRelation(
-          sqlContext.conf.useCompression,
-          sqlContext.conf.columnBatchSize,
+          query.sqlContext.conf.useCompression,
+          query.sqlContext.conf.columnBatchSize,
           storageLevel,
           query.queryExecution.executedPlan,
           tableName))
@@ -162,3 +168,4 @@ private[sql] object CacheManager extends Logging {
     }
   }
 }
+
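
Since CacheManager is now a process-wide object rather than a per-SQLContext instance, every session funnels through one cachedData buffer guarded by one ReentrantReadWriteLock. A stripped-down sketch of that locking discipline — the Entry type and method bodies are ours, while the readLock/writeLock helpers mirror the ones in the real file:

import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for Spark's CachedData entries.
case class Entry(name: String)

object CacheSketch {
  private[this] val cachedData = new ArrayBuffer[Entry]
  private[this] val cacheLock = new ReentrantReadWriteLock

  /** Runs `f` under the read lock: many concurrent readers are allowed. */
  private def readLock[A](f: => A): A = {
    val lock = cacheLock.readLock()
    lock.lock()
    try f finally lock.unlock()
  }

  /** Runs `f` under the write lock: writers get exclusive access. */
  private def writeLock[A](f: => A): A = {
    val lock = cacheLock.writeLock()
    lock.lock()
    try f finally lock.unlock()
  }

  def cache(name: String): Unit = writeLock { cachedData += Entry(name) }

  def isCached(name: String): Boolean =
    readLock { cachedData.exists(_.name == name) }

  def clearCache(): Unit = writeLock { cachedData.clear() }
}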

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 3 additions & 3 deletions
@@ -819,23 +819,23 @@ class DataFrame protected[sql](
    * @group basic
    */
   override def persist(): this.type = {
-    sqlContext.cacheManager.cacheQuery(this)
+    CacheManager.cacheQuery(this)
     this
   }

   /**
    * @group basic
    */
   override def persist(newLevel: StorageLevel): this.type = {
-    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+    CacheManager.cacheQuery(this, None, newLevel)
     this
   }

   /**
    * @group basic
    */
   override def unpersist(blocking: Boolean): this.type = {
-    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+    CacheManager.tryUncacheQuery(this, blocking)
     this
   }


sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 8 additions & 11 deletions
@@ -104,10 +104,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def getAllConfs: immutable.Map[String, String] = conf.getAllConfs

   @transient
-  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
+  protected[sql] lazy val catalog: Catalog = SimpleCaseSensitiveCatalog

   @transient
-  protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)
+  protected[sql] lazy val functionRegistry: FunctionRegistry = SimpleCaseSentiveFunctionRegistry

   @transient
   protected[sql] lazy val analyzer: Analyzer =
@@ -144,9 +144,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
     case _ =>
   }

-  @transient
-  protected[sql] val cacheManager = new CacheManager(this)
-
   /**
    * :: Experimental ::
    * A collection of methods that are considered experimental, but can be used to hook into
@@ -203,24 +200,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Returns true if the table is currently cached in-memory.
    * @group cachemgmt
    */
-  def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)
+  def isCached(tableName: String): Boolean = CacheManager.isCached(this, tableName)

   /**
    * Caches the specified table in-memory.
    * @group cachemgmt
    */
-  def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)
+  def cacheTable(tableName: String): Unit = CacheManager.cacheTable(this, tableName)

   /**
    * Removes the specified table from the in-memory cache.
    * @group cachemgmt
    */
-  def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+  def uncacheTable(tableName: String): Unit = CacheManager.uncacheTable(this, tableName)

   /**
    * Removes all cached tables from the in-memory cache.
    */
-  def clearCache(): Unit = cacheManager.clearCache()
+  def clearCache(): Unit = CacheManager.clearCache()

   // scalastyle:off
   // Disable style checker so "implicits" object can start with lowercase i
@@ -905,7 +902,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group basic
    */
   def dropTempTable(tableName: String): Unit = {
-    cacheManager.tryUncacheQuery(table(tableName))
+    CacheManager.tryUncacheQuery(table(tableName))
     catalog.unregisterTable(Seq(tableName))
   }

@@ -1066,7 +1063,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   protected[sql] class QueryExecution(val logical: LogicalPlan) {

     lazy val analyzed: LogicalPlan = analyzer(logical)
-    lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
+    lazy val withCachedData: LogicalPlan = CacheManager.useCachedData(analyzed)
     lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)

     // TODO: Don't just pick the first one...
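
The net effect of pointing catalog, functionRegistry, and the cache calls at singletons is that two SQLContext instances in the same JVM now share temporary tables and cached data. A hedged sketch of the observable behaviour, assuming a Spark build contemporary with this patch (the table and app names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SharedCatalogSketch extends App {
  val sc = new SparkContext(
    new SparkConf().setAppName("shared-catalog").setMaster("local"))

  // Two contexts, e.g. one per ThriftServer session.
  val ctx1 = new SQLContext(sc)
  val ctx2 = new SQLContext(sc)

  import ctx1.implicits._
  sc.parallelize(Seq((1, "a"), (2, "b")))
    .toDF("key", "value")
    .registerTempTable("shared_t")

  // Because the catalog is the shared SimpleCaseSensitiveCatalog, ctx2
  // resolves the table that was registered through ctx1.
  assert(ctx2.table("shared_t").count() == 2)

  sc.stop()
}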

sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 3 additions & 3 deletions
@@ -16,10 +16,10 @@
  */
 package org.apache.spark.sql.sources

-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{CacheManager, DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.{LogicalRDD, RunnableCommand}
+import org.apache.spark.sql.execution.RunnableCommand

 private[sql] case class InsertIntoDataSource(
     logicalRelation: LogicalRelation,
@@ -32,7 +32,7 @@ private[sql] case class InsertIntoDataSource(
     relation.insert(DataFrame(sqlContext, query), overwrite)

     // Invalidate the cache.
-    sqlContext.cacheManager.invalidateCache(logicalRelation)
+    CacheManager.invalidateCache(logicalRelation)

     Seq.empty[Row]
   }

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 12 additions & 12 deletions
@@ -56,17 +56,17 @@ class CachedTableSuite extends QueryTest {
   }

   test("unpersist an uncached table will not raise exception") {
-    assert(None == cacheManager.lookupCachedData(testData))
-    testData.unpersist(blocking = true)
-    assert(None == cacheManager.lookupCachedData(testData))
-    testData.unpersist(blocking = false)
-    assert(None == cacheManager.lookupCachedData(testData))
+    assert(None == CacheManager.lookupCachedData(testData))
+    testData.unpersist(true)
+    assert(None == CacheManager.lookupCachedData(testData))
+    testData.unpersist(false)
+    assert(None == CacheManager.lookupCachedData(testData))
     testData.persist()
-    assert(None != cacheManager.lookupCachedData(testData))
-    testData.unpersist(blocking = true)
-    assert(None == cacheManager.lookupCachedData(testData))
-    testData.unpersist(blocking = false)
-    assert(None == cacheManager.lookupCachedData(testData))
+    assert(None != CacheManager.lookupCachedData(testData))
+    testData.unpersist(true)
+    assert(None == CacheManager.lookupCachedData(testData))
+    testData.unpersist(false)
+    assert(None == CacheManager.lookupCachedData(testData))
   }

   test("cache table as select") {
@@ -287,13 +287,13 @@ class CachedTableSuite extends QueryTest {
     cacheTable("t1")
     cacheTable("t2")
     clearCache()
-    assert(cacheManager.isEmpty)
+    assert(CacheManager.isEmpty)

     sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
     sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
     cacheTable("t1")
     cacheTable("t2")
     sql("Clear CACHE")
-    assert(cacheManager.isEmpty)
+    assert(CacheManager.isEmpty)
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   }

   test("join operator selection") {
-    cacheManager.clearCache()
+    CacheManager.clearCache()

     Seq(
       ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
@@ -94,7 +94,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   }

   test("broadcasted hash join operator selection") {
-    cacheManager.clearCache()
+    CacheManager.clearCache()
     sql("CACHE TABLE testData")

     Seq(
@@ -385,7 +385,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   }

   test("broadcasted left semi join operator selection") {
-    cacheManager.clearCache()
+    CacheManager.clearCache()
     sql("CACHE TABLE testData")
     val tmp = conf.autoBroadcastJoinThreshold


sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala

Lines changed: 6 additions & 5 deletions
@@ -23,7 +23,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
 import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}

-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.hive.HiveContext
 import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
@@ -39,10 +39,11 @@ object HiveThriftServer2 extends Logging {
   /**
    * :: DeveloperApi ::
    * Starts a new thrift server with the given context.
+   * TODO probably a SparkContext, and HiveConf as parameter would be better
    */
   @DeveloperApi
   def startWithContext(sqlContext: HiveContext): Unit = {
-    val server = new HiveThriftServer2(sqlContext)
+    val server = new HiveThriftServer2(sqlContext.sparkContext)
     server.init(sqlContext.hiveconf)
     server.start()
     sqlContext.sparkContext.addSparkListener(new HiveThriftServer2Listener(server))
@@ -66,7 +67,7 @@ object HiveThriftServer2 extends Logging {
     )

     try {
-      val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
+      val server = new HiveThriftServer2(SparkSQLEnv.sparkContext)
       server.init(SparkSQLEnv.hiveContext.hiveconf)
       server.start()
       logInfo("HiveThriftServer2 started")
@@ -89,12 +90,12 @@ object HiveThriftServer2 extends Logging {

 }

-private[hive] class HiveThriftServer2(hiveContext: HiveContext)
+private[hive] class HiveThriftServer2(sc: SparkContext)
   extends HiveServer2
   with ReflectedCompositeService {

   override def init(hiveConf: HiveConf) {
-    val sparkSqlCliService = new SparkSQLCLIService(hiveContext)
+    val sparkSqlCliService = new SparkSQLCLIService(sc)
     setSuperField(this, "cliService", sparkSqlCliService)
     addService(sparkSqlCliService)

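
With the constructor taking a SparkContext, the server no longer pins a single HiveContext; the startWithContext entry point keeps its signature and simply unwraps the context. A hedged sketch of embedding the server this way, based on the startWithContext entry point shown above (configuration values are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

object EmbeddedThriftServerSketch extends App {
  val sc = new SparkContext(
    new SparkConf().setAppName("embedded-thrift").setMaster("local"))
  val hiveContext = new HiveContext(sc)

  // startWithContext still takes a HiveContext; after this patch it passes
  // only hiveContext.sparkContext down to the HiveThriftServer2 instance,
  // so each JDBC connection can carry its own session-level state.
  HiveThriftServer2.startWithContext(hiveContext)
}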

sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala

Lines changed: 5 additions & 3 deletions
@@ -21,6 +21,8 @@ import java.io.IOException
 import java.util.{List => JList}
 import javax.security.auth.login.LoginException

+import org.apache.spark.SparkContext
+
 import scala.collection.JavaConversions._

 import org.apache.commons.logging.Log
@@ -36,14 +38,14 @@ import org.apache.spark.sql.hive.HiveContext
 import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
 import org.apache.spark.util.Utils

-private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
+private[hive] class SparkSQLCLIService(sc: SparkContext)
   extends CLIService
   with ReflectedCompositeService {

   override def init(hiveConf: HiveConf) {
     setSuperField(this, "hiveConf", hiveConf)

-    val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext)
+    val sparkSqlSessionManager = new SparkSQLSessionManager(sc)
     setSuperField(this, "sessionManager", sparkSqlSessionManager)
     addService(sparkSqlSessionManager)
     var sparkServiceUGI: UserGroupInformation = null
@@ -66,7 +68,7 @@ private[hive] class SparkSQLCLIService(sc: SparkContext)
     getInfoType match {
       case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
       case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
-      case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version)
+      case GetInfoType.CLI_DBMS_VER => new GetInfoValue(sc.version)
       case _ => super.getInfo(sessionHandle, getInfoType)
     }
   }
