Skip to content

Commit 6ab9a0f

Browse files
committed
Remove duplicate code in tests.
1 parent 9ea2061 commit 6ab9a0f

File tree

1 file changed

+53
-57
lines changed

1 file changed

+53
-57
lines changed

core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.deploy.master
1919

2020
import java.util.Date
2121
import java.util.concurrent.ConcurrentLinkedQueue
22+
import java.util.concurrent.atomic.AtomicInteger
2223

2324
import scala.collection.JavaConverters._
2425
import scala.collection.mutable.HashMap
@@ -35,7 +36,41 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
3536
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
3637
import org.apache.spark.deploy._
3738
import org.apache.spark.deploy.DeployMessages._
38-
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}
39+
import org.apache.spark.rpc._
40+
41+
object MockWorker {
42+
val counter = new AtomicInteger(10000)
43+
}
44+
45+
class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extends RpcEndpoint {
46+
val seq = MockWorker.counter.incrementAndGet()
47+
val id = seq.toString
48+
override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq,
49+
conf, new SecurityManager(conf))
50+
var appRegistered = false
51+
def newDriver(): RpcEndpointRef = {
52+
val name = s"driver_${drivers.size}"
53+
rpcEnv.setupEndpoint(name, new RpcEndpoint {
54+
override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv
55+
override def receive: PartialFunction[Any, Unit] = {
56+
case RegisteredApplication(_, _) => appRegistered = true
57+
}
58+
})
59+
}
60+
61+
val appDesc = DeployTestUtils.createAppDesc()
62+
val drivers = new HashMap[String, String]
63+
override def receive: PartialFunction[Any, Unit] = {
64+
case RegisteredWorker(masterRef, _, _) =>
65+
masterRef.send(WorkerLatestState("1", Nil, drivers.keys.toSeq))
66+
case LaunchDriver(driverId, desc) =>
67+
drivers(driverId) = driverId
68+
master.send(RegisterApplication(appDesc, newDriver()))
69+
case KillDriver(driverId) =>
70+
master.send(DriverStateChanged(driverId, DriverState.KILLED, None))
71+
drivers.remove(driverId)
72+
}
73+
}
3974

4075
class MasterSuite extends SparkFunSuite
4176
with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
@@ -509,93 +544,54 @@ class MasterSuite extends SparkFunSuite
509544
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
510545
assert(masterState.status === RecoveryState.ALIVE, "Master is not alive")
511546
}
512-
513-
val app = DeployTestUtils.createAppDesc()
514-
var appId = ""
515-
val driverEnv1 = RpcEnv.create("driver1", "localhost", 22344, conf, new SecurityManager(conf))
516-
val fakeDriver1 = driverEnv1.setupEndpoint("driver", new RpcEndpoint {
517-
override val rpcEnv: RpcEnv = driverEnv1
518-
override def receive: PartialFunction[Any, Unit] = {
519-
case RegisteredApplication(id, _) => appId = id
520-
}
521-
})
522-
val drivers = new HashMap[String, String]
523-
val workerEnv1 = RpcEnv.create("worker1", "localhost", 12344, conf, new SecurityManager(conf))
524-
val fakeWorker1 = workerEnv1.setupEndpoint("worker", new RpcEndpoint {
525-
override val rpcEnv: RpcEnv = workerEnv1
526-
override def receive: PartialFunction[Any, Unit] = {
527-
case RegisteredWorker(masterRef, _, _) =>
528-
masterRef.send(WorkerLatestState("1", Nil, drivers.keys.toSeq))
529-
case LaunchDriver(id, desc) =>
530-
drivers(id) = id
531-
master.self.send(RegisterApplication(app, fakeDriver1))
532-
case KillDriver(driverId) =>
533-
master.self.send(DriverStateChanged(driverId, DriverState.KILLED, None))
534-
drivers.remove(driverId)
535-
}
536-
})
537-
val worker1 = RegisterWorker(
538-
"1",
547+
val worker1 = new MockWorker(master.self)
548+
worker1.rpcEnv.setupEndpoint("worker", worker1)
549+
val worker1Reg = RegisterWorker(
550+
worker1.id,
539551
"localhost",
540-
9999,
541-
fakeWorker1,
552+
9998,
553+
worker1.self,
542554
10,
543555
1024,
544556
"http://localhost:8080",
545557
RpcAddress("localhost2", 10000))
546-
master.self.send(worker1)
558+
master.self.send(worker1Reg)
547559
val driver = DeployTestUtils.createDriverDesc().copy(supervise = true)
548560
master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver))
549561

550562
eventually(timeout(10.seconds)) {
551-
assert(!appId.isEmpty)
563+
assert(!worker1.appRegistered)
552564
}
553565

554566
eventually(timeout(10.seconds)) {
555567
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
556568
assert(masterState.workers(0).state == WorkerState.DEAD)
557569
}
558570

559-
val driverEnv2 = RpcEnv.create("driver2", "localhost", 22345, conf, new SecurityManager(conf))
560-
val fakeDriver2 = driverEnv2.setupEndpoint("driver", new RpcEndpoint {
561-
override val rpcEnv: RpcEnv = driverEnv2
562-
override def receive: PartialFunction[Any, Unit] = {
563-
case RegisteredApplication(id, _) => appId = id
564-
}
565-
})
566-
val workerEnv2 = RpcEnv.create("worker2", "localhost", 12345, conf, new SecurityManager(conf))
567-
val fakeWorker2 = workerEnv2.setupEndpoint("worker2", new RpcEndpoint {
568-
override val rpcEnv: RpcEnv = workerEnv2
569-
override def receive: PartialFunction[Any, Unit] = {
570-
case LaunchDriver(_, _) =>
571-
master.self.send(RegisterApplication(app, fakeDriver2))
572-
}
573-
})
574-
575-
appId = ""
571+
val worker2 = new MockWorker(master.self)
572+
worker2.rpcEnv.setupEndpoint("worker", worker2)
576573
master.self.send(RegisterWorker(
577-
"2",
574+
worker2.id,
578575
"localhost",
579-
9998,
580-
fakeWorker2,
576+
9999,
577+
worker2.self,
581578
10,
582579
1024,
583580
"http://localhost:8081",
584-
RpcAddress("localhost2", 10001)))
581+
RpcAddress("localhost", 10001)))
585582
eventually(timeout(10.seconds)) {
586-
assert(!appId.isEmpty)
583+
assert(!worker2.appRegistered)
587584
}
588585

589-
master.self.send(worker1)
586+
master.self.send(worker1Reg)
590587
eventually(timeout(10.seconds)) {
591588
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
592589

593-
val worker = masterState.workers.filter(w => w.id == "1")
590+
val worker = masterState.workers.filter(w => w.id == worker1.id)
594591
assert(worker.length == 1)
595592
// make sure the `DriverStateChanged` arrives at Master.
596593
assert(worker(0).drivers.isEmpty)
597594
assert(masterState.activeDrivers.length == 1)
598-
assert(masterState.activeDrivers(0).state == DriverState.RUNNING)
599595
assert(masterState.activeApps.length == 1)
600596
}
601597
}

0 commit comments

Comments
 (0)