1818package org .apache .spark .sql .execution .streaming .state
1919
2020import java .util .concurrent .{ScheduledFuture , TimeUnit }
21+ import javax .annotation .concurrent .GuardedBy
2122
2223import scala .collection .mutable
2324import scala .util .control .NonFatal
@@ -124,12 +125,46 @@ object StateStore extends Logging {
124125 val MAINTENANCE_INTERVAL_CONFIG = " spark.sql.streaming.stateStore.maintenanceInterval"
125126 val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
126127
128+ @ GuardedBy (" loadedProviders" )
127129 private val loadedProviders = new mutable.HashMap [StateStoreId , StateStoreProvider ]()
128- private val maintenanceTaskExecutor =
129- ThreadUtils .newDaemonSingleThreadScheduledExecutor(" state-store-maintenance-task" )
130130
131- @ volatile private var maintenanceTask : ScheduledFuture [_] = null
132- @ volatile private var _coordRef : StateStoreCoordinatorRef = null
131+ /**
132+ * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
133+ * will be called when an exception happens.
134+ */
135+ class MaintenanceTask (periodMs : Long , task : => Unit , onError : => Unit ) {
136+ private val executor =
137+ ThreadUtils .newDaemonSingleThreadScheduledExecutor(" state-store-maintenance-task" )
138+
139+ private val runnable = new Runnable {
140+ override def run (): Unit = {
141+ try {
142+ task
143+ } catch {
144+ case NonFatal (e) =>
145+ logWarning(" Error running maintenance thread" , e)
146+ onError
147+ throw e
148+ }
149+ }
150+ }
151+
152+ private val future : ScheduledFuture [_] = executor.scheduleAtFixedRate(
153+ runnable, periodMs, periodMs, TimeUnit .MILLISECONDS )
154+
155+ def stop (): Unit = {
156+ future.cancel(false )
157+ executor.shutdown()
158+ }
159+
160+ def isRunning : Boolean = ! future.isDone
161+ }
162+
163+ @ GuardedBy (" loadedProviders" )
164+ private var maintenanceTask : MaintenanceTask = null
165+
166+ @ GuardedBy (" loadedProviders" )
167+ private var _coordRef : StateStoreCoordinatorRef = null
133168
134169 /** Get or create a store associated with the id. */
135170 def get (
@@ -162,15 +197,15 @@ object StateStore extends Logging {
162197 }
163198
164199 def isMaintenanceRunning : Boolean = loadedProviders.synchronized {
165- maintenanceTask != null
200+ maintenanceTask != null && maintenanceTask.isRunning
166201 }
167202
168203 /** Unload and stop all state store providers */
169204 def stop (): Unit = loadedProviders.synchronized {
170205 loadedProviders.clear()
171206 _coordRef = null
172207 if (maintenanceTask != null ) {
173- maintenanceTask.cancel( false )
208+ maintenanceTask.stop( )
174209 maintenanceTask = null
175210 }
176211 logInfo(" StateStore stopped" )
@@ -179,14 +214,14 @@ object StateStore extends Logging {
179214 /** Start the periodic maintenance task if not already started and if Spark active */
180215 private def startMaintenanceIfNeeded (): Unit = loadedProviders.synchronized {
181216 val env = SparkEnv .get
182- if (maintenanceTask == null && env != null ) {
217+ if (env != null && ! isMaintenanceRunning ) {
183218 val periodMs = env.conf.getTimeAsMs(
184219 MAINTENANCE_INTERVAL_CONFIG , s " ${MAINTENANCE_INTERVAL_DEFAULT_SECS }s " )
185- val runnable = new Runnable {
186- override def run () : Unit = { doMaintenance() }
187- }
188- maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
189- runnable, periodMs, periodMs, TimeUnit . MILLISECONDS )
220+ maintenanceTask = new MaintenanceTask (
221+ periodMs,
222+ task = { doMaintenance() },
223+ onError = { loadedProviders. synchronized { loadedProviders.clear() } }
224+ )
190225 logInfo(" State Store maintenance task started" )
191226 }
192227 }
@@ -198,21 +233,20 @@ object StateStore extends Logging {
198233 private def doMaintenance (): Unit = {
199234 logDebug(" Doing maintenance" )
200235 if (SparkEnv .get == null ) {
201- stop()
202- } else {
203- loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
204- try {
205- if (verifyIfStoreInstanceActive(id)) {
206- provider.doMaintenance()
207- } else {
208- unload(id)
209- logInfo(s " Unloaded $provider" )
210- }
211- } catch {
212- case NonFatal (e) =>
213- logWarning(s " Error managing $provider, stopping management thread " )
214- stop()
236+ throw new IllegalStateException (" SparkEnv not active, cannot do maintenance on StateStores" )
237+ }
238+ loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
239+ try {
240+ if (verifyIfStoreInstanceActive(id)) {
241+ provider.doMaintenance()
242+ } else {
243+ unload(id)
244+ logInfo(s " Unloaded $provider" )
215245 }
246+ } catch {
247+ case NonFatal (e) =>
248+ logWarning(s " Error managing $provider, stopping management thread " )
249+ throw e
216250 }
217251 }
218252 }
@@ -238,7 +272,7 @@ object StateStore extends Logging {
238272 }
239273 }
240274
241- private def coordinatorRef : Option [StateStoreCoordinatorRef ] = synchronized {
275+ private def coordinatorRef : Option [StateStoreCoordinatorRef ] = loadedProviders. synchronized {
242276 val env = SparkEnv .get
243277 if (env != null ) {
244278 if (_coordRef == null ) {
0 commit comments