@@ -19,6 +19,7 @@ package org.apache.spark.streaming.util
1919
2020import java .nio .ByteBuffer
2121import java .util .{Iterator => JIterator }
22+ import java .util .concurrent .atomic .AtomicBoolean
2223import java .util .concurrent .LinkedBlockingQueue
2324
2425import scala .collection .JavaConverters ._
@@ -60,7 +61,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
6061 private val walWriteQueue = new LinkedBlockingQueue [Record ]()
6162
6263 // Whether the writer thread is active
63- @ volatile private var active : Boolean = true
64+ private val active : AtomicBoolean = new AtomicBoolean ( true )
6465 private val buffer = new ArrayBuffer [Record ]()
6566
6667 private val batchedWriterThread = startBatchedWriterThread()
@@ -72,7 +73,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
7273 override def write (byteBuffer : ByteBuffer , time : Long ): WriteAheadLogRecordHandle = {
7374 val promise = Promise [WriteAheadLogRecordHandle ]()
7475 val putSuccessfully = synchronized {
75- if (active) {
76+ if (active.get() ) {
7677 walWriteQueue.offer(Record (byteBuffer, time, promise))
7778 true
7879 } else {
@@ -121,10 +122,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
121122 */
122123 override def close (): Unit = {
123124 logInfo(s " BatchedWriteAheadLog shutting down at time: ${System .currentTimeMillis()}. " )
124- synchronized {
125- if (! active) return
126- active = false
127- }
125+ if (! active.getAndSet(false )) return
128126 batchedWriterThread.interrupt()
129127 batchedWriterThread.join()
130128 while (! walWriteQueue.isEmpty) {
@@ -139,7 +137,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
139137 private def startBatchedWriterThread (): Thread = {
140138 val thread = new Thread (new Runnable {
141139 override def run (): Unit = {
142- while (active) {
140+ while (active.get() ) {
143141 try {
144142 flushRecords()
145143 } catch {
@@ -167,7 +165,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
167165 }
168166 try {
169167 var segment : WriteAheadLogRecordHandle = null
170- if (buffer.length > 0 ) {
168+ if (buffer.nonEmpty ) {
171169 logDebug(s " Batched ${buffer.length} records for Write Ahead Log write " )
172170 // threads may not be able to add items in order by time
173171 val sortedByTime = buffer.sortBy(_.time)
0 commit comments