Commit d386cdf (merged, 2 parents: 5439468 + d9ad789)

76 files changed: +1074, -434 lines

R/pkg/vignettes/sparkr-vignettes.Rmd

Lines changed: 35 additions & 15 deletions
@@ -503,6 +503,8 @@ SparkR supports the following machine learning models and algorithms.
 
 #### Tree - Classification and Regression
 
+* Decision Tree
+
 * Gradient-Boosted Trees (GBT)
 
 * Random Forest
@@ -776,16 +778,32 @@ newDF <- createDataFrame(data.frame(x = c(1.5, 3.2)))
 head(predict(isoregModel, newDF))
 ```
 
+#### Decision Tree
+
+`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`.
+Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
+
+We use the `Titanic` dataset to train a decision tree and make predictions:
+
+```{r}
+t <- as.data.frame(Titanic)
+df <- createDataFrame(t)
+dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2)
+summary(dtModel)
+predictions <- predict(dtModel, df)
+```
+
 #### Gradient-Boosted Trees
 
 `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`.
 Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
 
-We use the `longley` dataset to train a gradient-boosted tree and make predictions:
+We use the `Titanic` dataset to train a gradient-boosted tree and make predictions:
 
-```{r, warning=FALSE}
-df <- createDataFrame(longley)
-gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2)
+```{r}
+t <- as.data.frame(Titanic)
+df <- createDataFrame(t)
+gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2)
 summary(gbtModel)
 predictions <- predict(gbtModel, df)
 ```
@@ -795,11 +813,12 @@ predictions <- predict(gbtModel, df)
 `spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`.
 Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
 
-In the following example, we use the `longley` dataset to train a random forest and make predictions:
+In the following example, we use the `Titanic` dataset to train a random forest and make predictions:
 
-```{r, warning=FALSE}
-df <- createDataFrame(longley)
-rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2)
+```{r}
+t <- as.data.frame(Titanic)
+df <- createDataFrame(t)
+rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2)
 summary(rfModel)
 predictions <- predict(rfModel, df)
 ```
@@ -965,17 +984,18 @@ Given a `SparkDataFrame`, the test compares continuous data in a given column `t
 specified by parameter `nullHypothesis`.
 Users can call `summary` to get a summary of the test results.
 
-In the following example, we test whether the `longley` dataset's `Armed_Forces` column
+In the following example, we test whether the `Titanic` dataset's `Freq` column
 follows a normal distribution. We set the parameters of the normal distribution using
 the mean and standard deviation of the sample.
 
-```{r, warning=FALSE}
-df <- createDataFrame(longley)
-afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces)))
-afMean <- afStats[1]
-afStd <- afStats[2]
+```{r}
+t <- as.data.frame(Titanic)
+df <- createDataFrame(t)
+freqStats <- head(select(df, mean(df$Freq), sd(df$Freq)))
+freqMean <- freqStats[1]
+freqStd <- freqStats[2]
 
-test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd))
+test <- spark.kstest(df, "Freq", "norm", c(freqMean, freqStd))
 testSummary <- summary(test)
 testSummary
 ```

common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 21 additions & 0 deletions
@@ -23,6 +23,8 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import scala.Tuple2;
+
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
@@ -94,6 +96,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     return nextChunk;
   }
 
+  @Override
+  public ManagedBuffer openStream(String streamChunkId) {
+    Tuple2<Long, Integer> streamIdAndChunkId = parseStreamChunkId(streamChunkId);
+    return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2);
+  }
+
+  public static String genStreamChunkId(long streamId, int chunkId) {
+    return String.format("%d_%d", streamId, chunkId);
+  }
+
+  public static Tuple2<Long, Integer> parseStreamChunkId(String streamChunkId) {
+    String[] array = streamChunkId.split("_");
+    assert array.length == 2:
+      "Stream id and chunk index should be specified when open stream for fetching block.";
+    long streamId = Long.valueOf(array[0]);
+    int chunkIndex = Integer.valueOf(array[1]);
+    return new Tuple2<>(streamId, chunkIndex);
+  }
+
   @Override
   public void connectionTerminated(Channel channel) {
     // Close all streams which have been associated with the channel.
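For orientation, a minimal sketch of how the new helpers compose: `genStreamChunkId` encodes a (streamId, chunkIndex) pair as the `streamChunkId` string that `openStream` later decodes via `parseStreamChunkId` to serve the chunk. The concrete values below are illustrative, not taken from the commit.

```java
import scala.Tuple2;

import org.apache.spark.network.server.OneForOneStreamManager;

class StreamChunkIdRoundTrip {
  public static void main(String[] args) {
    // Encode stream 123, chunk 4 as the id a client would pass to openStream().
    String streamChunkId = OneForOneStreamManager.genStreamChunkId(123L, 4); // "123_4"
    // Decode it back into the (streamId, chunkIndex) pair used by getChunk().
    Tuple2<Long, Integer> parsed = OneForOneStreamManager.parseStreamChunkId(streamChunkId);
    assert parsed._1 == 123L && parsed._2 == 4;
  }
}
```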

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java

Lines changed: 5 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
@@ -86,14 +87,16 @@ public void fetchBlocks(
       int port,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener) {
+      BlockFetchingListener listener,
+      File[] shuffleFiles) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
       RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
           (blockIds1, listener1) -> {
             TransportClient client = clientFactory.createClient(host, port);
-            new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start();
+            new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
+              shuffleFiles).start();
           };
 
       int maxRetries = conf.maxIORetries();

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java

Lines changed: 60 additions & 2 deletions
@@ -17,19 +17,28 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
@@ -48,6 +57,8 @@ public class OneForOneBlockFetcher {
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
+  private TransportConf transportConf = null;
+  private File[] shuffleFiles = null;
 
   private StreamHandle streamHandle = null;
 
@@ -56,12 +67,20 @@ public OneForOneBlockFetcher(
       String appId,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener) {
+      BlockFetchingListener listener,
+      TransportConf transportConf,
+      File[] shuffleFiles) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
+    this.transportConf = transportConf;
+    if (shuffleFiles != null) {
+      this.shuffleFiles = shuffleFiles;
+      assert this.shuffleFiles.length == blockIds.length:
+        "Number of shuffle files should equal to blocks";
+    }
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -100,7 +119,12 @@ public void onSuccess(ByteBuffer response) {
       // Immediately request all chunks -- we expect that the total size of the request is
       // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
       for (int i = 0; i < streamHandle.numChunks; i++) {
-        client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+        if (shuffleFiles != null) {
+          client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
+            new DownloadCallback(shuffleFiles[i], i));
+        } else {
+          client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+        }
       }
     } catch (Exception e) {
       logger.error("Failed while starting block fetches after success", e);
@@ -126,4 +150,38 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
       }
     }
   }
+
+  private class DownloadCallback implements StreamCallback {
+
+    private WritableByteChannel channel = null;
+    private File targetFile = null;
+    private int chunkIndex;
+
+    public DownloadCallback(File targetFile, int chunkIndex) throws IOException {
+      this.targetFile = targetFile;
+      this.channel = Channels.newChannel(new FileOutputStream(targetFile));
+      this.chunkIndex = chunkIndex;
+    }
+
+    @Override
+    public void onData(String streamId, ByteBuffer buf) throws IOException {
+      channel.write(buf);
+    }
+
+    @Override
+    public void onComplete(String streamId) throws IOException {
+      channel.close();
+      ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
+        targetFile.length());
+      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+    }
+
+    @Override
+    public void onFailure(String streamId, Throwable cause) throws IOException {
+      channel.close();
+      // On receipt of a failure, fail every block from chunkIndex onwards.
+      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+      failRemainingBlocks(remainingBlockIds, cause);
+    }
+  }
 }
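Taken together, the fetcher now has two modes: passing `null` for `shuffleFiles` preserves the original in-memory `fetchChunk` path, while passing one target `File` per block streams each chunk to disk through `DownloadCallback` and hands the listener a `FileSegmentManagedBuffer`. A minimal sketch of the two call shapes, assuming a caller already has a connected `TransportClient`, a listener, and a `TransportConf` (the helper class and temp-file names below are illustrative, not part of the commit):

```java
package org.apache.spark.network.shuffle;

import java.io.File;
import java.io.IOException;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.util.TransportConf;

class FetchModeSketch {
  static void fetch(TransportClient client, String appId, String execId,
      String[] blockIds, BlockFetchingListener listener, TransportConf conf,
      boolean streamToDisk) throws IOException {
    if (!streamToDisk) {
      // In-memory mode (previous behaviour): chunks arrive through ChunkReceivedCallback.
      new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, conf, null).start();
      return;
    }
    // On-disk mode: one target file per block id; each chunk is streamed into its file
    // by DownloadCallback and surfaced to the listener as a FileSegmentManagedBuffer.
    File[] shuffleFiles = new File[blockIds.length];
    for (int i = 0; i < blockIds.length; i++) {
      shuffleFiles[i] = File.createTempFile("shuffle-block-" + i + "-", ".data");
    }
    new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, conf, shuffleFiles)
      .start();
  }
}
```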

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark.network.shuffle;
 
 import java.io.Closeable;
+import java.io.File;
 
 /** Provides an interface for reading shuffle files, either from an Executor or external service. */
 public abstract class ShuffleClient implements Closeable {
@@ -40,5 +41,6 @@ public abstract void fetchBlocks(
       int port,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener);
+      BlockFetchingListener listener,
+      File[] shuffleFiles);
 }

common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) {
 
     String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" };
     OneForOneBlockFetcher fetcher =
-      new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener);
+      new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
     fetcher.start();
     blockFetchLatch.await();
     checkSecurityException(exception.get());

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) {
           }
         }
       }
-    });
+    }, null);
 
     if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
       fail("Timeout getting response from the server");

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java

Lines changed: 6 additions & 1 deletion
@@ -46,8 +46,13 @@
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
 
 public class OneForOneBlockFetcherSuite {
+
+  private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
+
   @Test
   public void testFetchOne() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
@@ -126,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap<String, ManagedBu
     BlockFetchingListener listener = mock(BlockFetchingListener.class);
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
     OneForOneBlockFetcher fetcher =
-      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
+      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null);
 
     // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
     doAnswer(invocationOnMock -> {

core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java

Lines changed: 9 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import javax.annotation.concurrent.GuardedBy;
 import java.io.IOException;
+import java.nio.channels.ClosedByInterruptException;
 import java.util.Arrays;
 import java.util.ArrayList;
 import java.util.BitSet;
@@ -184,6 +185,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
             break;
           }
         }
+      } catch (ClosedByInterruptException e) {
+        // This called by user to kill a task (e.g: speculative task).
+        logger.error("error while calling spill() on " + c, e);
+        throw new RuntimeException(e.getMessage());
       } catch (IOException e) {
        logger.error("error while calling spill() on " + c, e);
        throw new OutOfMemoryError("error while calling spill() on " + c + " : "
@@ -201,6 +206,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
         Utils.bytesToString(released), consumer);
       got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
     }
+    } catch (ClosedByInterruptException e) {
+      // This called by user to kill a task (e.g: speculative task).
+      logger.error("error while calling spill() on " + consumer, e);
+      throw new RuntimeException(e.getMessage());
     } catch (IOException e) {
       logger.error("error while calling spill() on " + consumer, e);
       throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
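Worth noting: `ClosedByInterruptException` is itself an `IOException`, so the new catch block has to come before the existing `IOException` handler for a task kill (for example a cancelled speculative attempt) to surface as a plain `RuntimeException` instead of being escalated to `OutOfMemoryError`. A minimal sketch of that catch ordering, detached from the memory-manager internals (the `Spillable` interface below is illustrative, standing in for `MemoryConsumer.spill()`):

```java
import java.io.IOException;
import java.nio.channels.ClosedByInterruptException;

class SpillErrorHandlingSketch {
  interface Spillable {
    long spill() throws IOException; // stand-in for MemoryConsumer.spill()
  }

  static long spillOrFail(Spillable consumer) {
    try {
      return consumer.spill();
    } catch (ClosedByInterruptException e) {
      // The task was interrupted/killed; report a plain runtime failure, not OOM.
      throw new RuntimeException(e.getMessage());
    } catch (IOException e) {
      // A genuine spill failure keeps the original escalation to OutOfMemoryError.
      throw new OutOfMemoryError("error while calling spill(): " + e.getMessage());
    }
  }
}
```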

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 6 additions & 9 deletions
@@ -422,17 +422,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
     for (int partition = 0; partition < numPartitions; partition++) {
       for (int i = 0; i < spills.length; i++) {
         final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-        long bytesToTransfer = partitionLengthInSpill;
         final FileChannel spillInputChannel = spillInputChannels[i];
         final long writeStartTime = System.nanoTime();
-        while (bytesToTransfer > 0) {
-          final long actualBytesTransferred = spillInputChannel.transferTo(
-            spillInputChannelPositions[i],
-            bytesToTransfer,
-            mergedFileOutputChannel);
-          spillInputChannelPositions[i] += actualBytesTransferred;
-          bytesToTransfer -= actualBytesTransferred;
-        }
+        Utils.copyFileStreamNIO(
+          spillInputChannel,
+          mergedFileOutputChannel,
+          spillInputChannelPositions[i],
+          partitionLengthInSpill);
+        spillInputChannelPositions[i] += partitionLengthInSpill;
         writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
         bytesWrittenToMergedFile += partitionLengthInSpill;
         partitionLengths[partition] += partitionLengthInSpill;
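The removed `transferTo` loop is presumably what the shared `Utils.copyFileStreamNIO` helper now encapsulates, so the same retry-until-complete copy can be reused elsewhere. A minimal sketch of such a helper, assuming roughly the signature the call site above implies (this is an assumption, not the actual Spark implementation):

```java
import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;

// Hypothetical stand-in for the shared NIO copy helper the diff delegates to.
final class NioCopySketch {
  /** Copies exactly {@code length} bytes from {@code input}, starting at {@code startPosition}. */
  static void copyFileStreamNIO(
      FileChannel input,
      WritableByteChannel output,
      long startPosition,
      long length) throws IOException {
    long position = startPosition;
    long remaining = length;
    // transferTo may move fewer bytes than requested, so loop until the full
    // range has been copied -- the same pattern the inlined code used before.
    while (remaining > 0) {
      long transferred = input.transferTo(position, remaining, output);
      position += transferred;
      remaining -= transferred;
    }
  }
}
```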
