|
17 | 17 |
|
18 | 18 | package org.apache.spark |
19 | 19 |
|
| 20 | +import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
| 21 | + |
20 | 22 | import org.scalatest.Matchers |
21 | 23 |
|
22 | 24 | import org.apache.spark.ShuffleSuite.NonJavaSerializableClass |
| 25 | +import org.apache.spark.memory.TaskMemoryManager |
23 | 26 | import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} |
24 | | -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} |
| 27 | +import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
25 | 28 | import org.apache.spark.serializer.KryoSerializer |
| 29 | +import org.apache.spark.shuffle.ShuffleWriter |
26 | 30 | import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} |
27 | 31 | import org.apache.spark.util.MutablePair |
28 | 32 |
|
@@ -317,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC |
317 | 321 | assert(metrics.bytesWritten === metrics.bytesRead)
318 | 322 | assert(metrics.bytesWritten > 0) |
319 | 323 | } |
| 324 | + |
| 325 | + test("multiple simultaneous attempts for one task (SPARK-8029)") { |
| 326 | + sc = new SparkContext("local", "test", conf) |
| 327 | + val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] |
| 328 | + val manager = sc.env.shuffleManager |
| 329 | + |
| 330 | + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L) |
| 331 | + val metricsSystem = sc.env.metricsSystem |
| 332 | + val shuffleMapRdd = new MyRDD(sc, 1, Nil) |
| 333 | + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) |
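| | + // register shuffle 0 with a single map output, matching the one-partition RDD above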
| 334 | + val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) |
| 335 | + |
| 336 | + // first attempt -- it's successful
| 337 | + val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, |
| 338 | + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem, |
| 339 | + InternalAccumulator.create(sc))) |
| 340 | + val data1 = (1 to 10).map { x => x -> x }
| 341 | + |
| 342 | + // second attempt -- also successful. We'll write out different data, |
| 343 | + // just to simulate the fact that the records may get written differently |
| 344 | + // depending on what gets spilled, what gets combined, etc. |
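| | + // note the distinct taskAttemptId (1L below, vs. 0L for the first attempt)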
| 345 | + val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, |
| 346 | + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem, |
| 347 | + InternalAccumulator.create(sc))) |
| 348 | + val data2 = (11 to 20).map { x => x -> x }
| 349 | + |
| 350 | + // interleave writes of both attempts -- we want to test that both attempts can occur |
| 351 | + // simultaneously, and everything is still OK |
| 352 | + |
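| | + // helper: push every record through the writer, then commit via stop(true),
| | + // which returns the attempt's MapStatus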
| 353 | + def writeAndClose(
| 354 | + writer: ShuffleWriter[Int, Int])(
| 355 | + iter: Iterator[(Int, Int)]): Option[MapStatus] = {
| 356 | + writer.write(iter)
| 357 | + writer.stop(true)
| 358 | + }
| 359 | + val interleaver = new InterleaveIterators( |
| 360 | + data1, writeAndClose(writer1), data2, writeAndClose(writer2)) |
| 361 | + val (mapOutput1, mapOutput2) = interleaver.run() |
| 362 | + |
| 363 | + // check that we can read the map output and it has the right data |
| 364 | + assert(mapOutput1.isDefined) |
| 365 | + assert(mapOutput2.isDefined) |
| 366 | + assert(mapOutput1.get.location === mapOutput2.get.location) |
| 367 | + assert(mapOutput1.get.getSizeForBlock(0) === mapOutput2.get.getSizeForBlock(0))
| 368 | + |
| 369 | + // register one of the map outputs -- doesn't matter which one |
| 370 | + mapOutput1.foreach { mapStatus =>
| 371 | + mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) |
| 372 | + } |
| 373 | + |
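| | + // read the output back as a reducer would, under a third task context (taskAttemptId 2L)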
| 374 | + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, |
| 375 | + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem, |
| 376 | + InternalAccumulator.create(sc))) |
| 377 | + val readData = reader.read().toIndexedSeq |
| 378 | + assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) |
| 379 | + |
| 380 | + manager.unregisterShuffle(0) |
| 381 | + } |
| 382 | +} |
| 383 | + |
| 384 | +/** |
| 385 | + * Utility for tests that need to process two different iterators simultaneously in
| 386 | + * different threads. It guarantees that a test does not completely process data1 with f1
| 387 | + * before processing data2 with f2 (or vice versa): a barrier forces each function to
| 388 | + * process one element at a time, pausing until the other function "catches up".
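| | + *
| | + * For example, a minimal usage sketch, interleaving two sums:
| | + * {{{
| | + *   val (sum1, sum2) = new InterleaveIterators[Int, Int](
| | + *     Seq(1, 2, 3), (it: Iterator[Int]) => it.sum,
| | + *     Seq(4, 5, 6), (it: Iterator[Int]) => it.sum).run()
| | + * }}}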
| 389 | + */ |
| 390 | +class InterleaveIterators[T, R]( |
| 391 | + data1: Seq[T], |
| 392 | + f1: Iterator[T] => R, |
| 393 | + data2: Seq[T], |
| 394 | + f2: Iterator[T] => R) { |
| 395 | + |
| 396 | + require(data1.size == data2.size) |
| 397 | + |
| 398 | + val barrier = new CyclicBarrier(2) |
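| | + // wraps an iterator so each next() first rendezvouses with the other thread at the barrier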
| 399 | + class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] { |
| 400 | + def hasNext: Boolean = sub.hasNext |
| 401 | + |
| 402 | + def next(): E = {
| 403 | + barrier.await() |
| 404 | + sub.next() |
| 405 | + } |
| 406 | + } |
| 407 | + |
| 408 | + val c1 = new Callable[R] { |
| 409 | + override def call(): R = f1(new BarrierIterator(1, data1.iterator)) |
| 410 | + } |
| 411 | + val c2 = new Callable[R] { |
| 412 | + override def call(): R = f2(new BarrierIterator(2, data2.iterator)) |
| 413 | + } |
| 414 | + |
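| | + // two threads, one per callable -- with fewer, the first would block forever at the barrier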
| 415 | + val e: ExecutorService = Executors.newFixedThreadPool(2) |
| 416 | + |
| 417 | + def run(): (R, R) = { |
| 418 | + val future1 = e.submit(c1) |
| 419 | + val future2 = e.submit(c2) |
| 420 | + val r1 = future1.get() |
| 421 | + val r2 = future2.get() |
| 422 | + e.shutdown() |
| 423 | + (r1, r2) |
| 424 | + } |
320 | 425 | } |
321 | 426 |
|
322 | 427 | object ShuffleSuite { |
|