Commit 599dcb0
Merge pull request alteryx#74 from rxin/kill
2 parents 8de9706 + 806f3a3

Job cancellation via job group id.

This PR adds a simple API to group together a set of jobs belonging to a thread and the threads spawned from it, and to cancel all jobs in that group. An example:

```scala
sc.setJobGroup("this_is_the_group_id", "some job description")
sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
```

In a separate thread:

```scala
sc.cancelJobGroup("this_is_the_group_id")
```
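Pulling those pieces together, here is a minimal, self-contained driver sketch of the new API (hypothetical example code, not part of this commit; `JobGroupExample` and the app name are made up): one thread tags the jobs it submits with a group id, and the main thread cancels the whole group.

```scala
import org.apache.spark.SparkContext

object JobGroupExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "job-group-example")

    // Worker thread: everything it submits is tagged with the group id.
    val worker = new Thread {
      override def run() {
        sc.setJobGroup("this_is_the_group_id", "some job description")
        try {
          sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
        } catch {
          // Cancellation surfaces as a failed job on the submitting thread.
          case e: Exception => println("job was cancelled: " + e.getMessage)
        }
      }
    }
    worker.start()

    // Main thread: give the job time to start, then cancel the whole group.
    Thread.sleep(1000)
    sc.cancelJobGroup("this_is_the_group_id")
    worker.join()
    sc.stop()
  }
}
```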

File tree: 4 files changed (+75, -7 lines)

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 20 additions & 2 deletions

```diff
@@ -287,8 +287,19 @@ class SparkContext(
     Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
 
   /** Set a human readable description of the current job. */
+  @deprecated("use setJobGroup", "0.8.1")
   def setJobDescription(value: String) {
-    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+    setJobGroup("", value)
+  }
+
+  def setJobGroup(groupId: String, description: String) {
+    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+    setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+  }
+
+  def clearJobGroup() {
+    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+    setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
   }
 
   // Post init
@@ -866,10 +877,14 @@ class SparkContext(
       callSite,
       allowLocal = false,
       resultHandler,
-      null)
+      localProperties.get)
     new SimpleFutureAction(waiter, resultFunc)
   }
 
+  def cancelJobGroup(groupId: String) {
+    dagScheduler.cancelJobGroup(groupId)
+  }
+
   /**
    * Cancel all jobs that have been scheduled or are running.
    */
@@ -933,8 +948,11 @@
  * various Spark features.
  */
 object SparkContext {
+
   val SPARK_JOB_DESCRIPTION = "spark.job.description"
 
+  val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+
   implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
     def addInPlace(t1: Double, t2: Double): Double = t1 + t2
     def zero(initialValue: Double) = 0.0
```
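Note that async jobs now submit `localProperties.get` instead of `null`, so the group id set on the calling thread actually reaches the scheduler. The group id rides along in SparkContext's thread-local job properties, and the commit message's promise that threads spawned from a tagged thread share its group relies on thread-local inheritance. A minimal sketch of that inheritance mechanism in isolation (illustrative names, plain Scala, no Spark involved):

```scala
import java.util.Properties

object InheritedPropsDemo {
  // A child thread gets a view backed by its parent's properties at spawn
  // time, mirroring how a spawned thread can inherit its parent's group id.
  val localProps = new InheritableThreadLocal[Properties] {
    override def childValue(parent: Properties) = new Properties(parent)
  }

  def main(args: Array[String]) {
    localProps.set(new Properties())
    localProps.get.setProperty("spark.jobGroup.id", "group_1")

    val child = new Thread {
      override def run() {
        // Inherited from the parent thread: prints "group_1".
        println(localProps.get.getProperty("spark.jobGroup.id"))
      }
    }
    child.start()
    child.join()
  }
}
```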

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 21 additions & 5 deletions

```diff
@@ -277,11 +277,6 @@ class DAGScheduler(
     resultHandler: (Int, U) => Unit,
     properties: Properties = null): JobWaiter[U] =
   {
-    val jobId = nextJobId.getAndIncrement()
-    if (partitions.size == 0) {
-      return new JobWaiter[U](this, jobId, 0, resultHandler)
-    }
-
     // Check to make sure we are not launching a task on a partition that does not exist.
     val maxPartitions = rdd.partitions.length
     partitions.find(p => p >= maxPartitions).foreach { p =>
@@ -290,6 +285,11 @@
         "Total number of partitions: " + maxPartitions)
     }
 
+    val jobId = nextJobId.getAndIncrement()
+    if (partitions.size == 0) {
+      return new JobWaiter[U](this, jobId, 0, resultHandler)
+    }
+
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
@@ -342,6 +342,11 @@
     eventQueue.put(JobCancelled(jobId))
   }
 
+  def cancelJobGroup(groupId: String) {
+    logInfo("Asked to cancel job group " + groupId)
+    eventQueue.put(JobGroupCancelled(groupId))
+  }
+
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
@@ -381,6 +386,17 @@
         taskSched.cancelTasks(stage.id)
       }
 
+      case JobGroupCancelled(groupId) =>
+        // Cancel all jobs belonging to this job group.
+        // First finds all active jobs with this group id, and then kill stages for them.
+        val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+          .map(_.jobId)
+        if (!jobIds.isEmpty) {
+          running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+            taskSched.cancelTasks(stage.id)
+          }
+        }
+
       case AllJobsCancelled =>
         // Cancel all running jobs.
         running.foreach { stage =>
```
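Two things happen in this diff. First, the jobId allocation and the empty-partition early return move below the partition-range check, so a request naming a nonexistent partition is rejected before a JobWaiter is created. Second, group cancellation reaches the scheduler as a queued event rather than a direct call, keeping all job bookkeeping on the scheduler's own thread. A stripped-down sketch of that queued-event pattern (`MiniJob` and `MiniScheduler` are illustrative stand-ins, not the real DAGScheduler types):

```scala
import java.util.Properties
import java.util.concurrent.LinkedBlockingQueue

sealed trait SchedulerEvent
case class JobGroupCancelled(groupId: String) extends SchedulerEvent

case class MiniJob(jobId: Int, properties: Properties)

class MiniScheduler {
  val eventQueue = new LinkedBlockingQueue[SchedulerEvent]
  var activeJobs = Set.empty[MiniJob]

  // Called from any thread: just enqueue the request.
  def cancelJobGroup(groupId: String) {
    eventQueue.put(JobGroupCancelled(groupId))
  }

  // Normally runs in a dedicated scheduler thread, which owns activeJobs.
  def processNextEvent() {
    eventQueue.take() match {
      case JobGroupCancelled(groupId) =>
        // Find active jobs tagged with this group id, then cancel them.
        val jobIds = activeJobs
          .filter(j => groupId == j.properties.get("spark.jobGroup.id"))
          .map(_.jobId)
        jobIds.foreach(id => println("would cancel tasks of job " + id))
    }
  }
}

object MiniSchedulerDemo {
  def main(args: Array[String]) {
    val sched = new MiniScheduler
    val props = new Properties()
    props.setProperty("spark.jobGroup.id", "group_1")
    sched.activeJobs = Set(MiniJob(1, props), MiniJob(2, new Properties()))
    sched.cancelJobGroup("group_1")
    sched.processNextEvent()  // prints: would cancel tasks of job 1
  }
}
```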

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 2 additions & 0 deletions

```diff
@@ -46,6 +46,8 @@ private[scheduler] case class JobSubmitted(
 
 private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
 
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
 private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 
 private[scheduler]
```

core/src/test/scala/org/apache/spark/JobCancellationSuite.scala

Lines changed: 32 additions & 0 deletions

```diff
@@ -19,6 +19,8 @@ package org.apache.spark
 
 import java.util.concurrent.Semaphore
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
 import scala.concurrent.future
 import scala.concurrent.ExecutionContext.Implicits.global
 
@@ -83,6 +85,36 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(sc.parallelize(1 to 10, 2).count === 10)
   }
 
+  test("job group") {
+    sc = new SparkContext("local[2]", "test")
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.dagScheduler.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled")
+      sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+    }
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(jobB.get() === 100)
+  }
+
   test("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // sem2: make sure the first stage is not finished until cancel is issued
```
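The listener-plus-semaphore handshake in this test is a reusable pattern: each started task releases one permit, and the test acquires that many permits before issuing the cancel, so the cancellation is guaranteed to hit work that is genuinely running. A Spark-free sketch of the same idea (hypothetical demo code, using the same old-style scala.concurrent.future API as the suite):

```scala
import java.util.concurrent.Semaphore
import scala.concurrent.{Await, future}
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global

object SemaphoreGate {
  def main(args: Array[String]) {
    val started = new Semaphore(0)

    val workers = (1 to 2).map { i =>
      future {
        started.release()   // signal "this task has launched"
        Thread.sleep(1000)  // simulate interruptible work
        i
      }
    }

    started.acquire(2)      // block until both workers are actually running
    println("both workers started; safe to issue cancellation now")
    workers.foreach(w => Await.result(w, Duration.Inf))
  }
}
```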
