Skip to content

Commit b9e0e1e

Browse files
Leemoonsoo authored and resende committed
scala 2.11 support for spark interpreter
1 parent c88348d commit b9e0e1e

File tree

3 files changed

+107
-35
lines changed

3 files changed

+107
-35
lines changed

spark/src/main/java/org/apache/zeppelin/spark/DepInterpreter.java

Lines changed: 15 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -52,9 +52,11 @@
5252
import scala.None;
5353
import scala.Some;
5454
import scala.collection.convert.WrapAsJava$;
55+
import scala.collection.JavaConversions;
5556
import scala.tools.nsc.Settings;
5657
import scala.tools.nsc.interpreter.Completion.Candidates;
5758
import scala.tools.nsc.interpreter.Completion.ScalaCompleter;
59+
import scala.tools.nsc.interpreter.IMain;
5860
import scala.tools.nsc.interpreter.Results;
5961
import scala.tools.nsc.settings.MutableSettings.BooleanSetting;
6062
import scala.tools.nsc.settings.MutableSettings.PathSetting;
@@ -180,7 +182,12 @@ private void createIMain() {
180182
new Object[]{intp});
181183
}
182184
interpret("@transient var _binder = new java.util.HashMap[String, Object]()");
183-
Map<String, Object> binder = (Map<String, Object>) getValue("_binder");
185+
Map<String, Object> binder;
186+
if (isScala2_10()) {
187+
binder = (Map<String, Object>) getValue("_binder");
188+
} else {
189+
binder = (Map<String, Object>) getLastObject();
190+
}
184191
binder.put("depc", depc);
185192

186193
interpret("@transient val z = "
@@ -208,6 +215,13 @@ public Object getValue(String name) {
208215
}
209216
}
210217

218+
public Object getLastObject() {
219+
IMain.Request r = (IMain.Request) invokeMethod(intp, "lastRequest");
220+
Object obj = r.lineRep().call("$result",
221+
JavaConversions.asScalaBuffer(new LinkedList<Object>()));
222+
return obj;
223+
}
224+
211225
@Override
212226
public InterpreterResult interpret(String st, InterpreterContext context) {
213227
PrintStream printStream = new PrintStream(out);
@@ -285,7 +299,6 @@ public List<InterpreterCompletion> completion(String buf, int cursor) {
285299
} else {
286300
return new LinkedList<InterpreterCompletion>();
287301
}
288-
289302
}
290303

291304
private List<File> currentClassPath() {

spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java

Lines changed: 88 additions & 33 deletions
Original file line number · Diff line number · Diff line change
@@ -39,12 +39,14 @@
3939
import org.apache.spark.SparkContext;
4040
import org.apache.spark.SparkEnv;
4141

42+
import org.apache.spark.SecurityManager;
4243
import org.apache.spark.repl.SparkILoop;
4344
import org.apache.spark.scheduler.ActiveJob;
4445
import org.apache.spark.scheduler.DAGScheduler;
4546
import org.apache.spark.scheduler.Pool;
4647
import org.apache.spark.sql.SQLContext;
4748
import org.apache.spark.ui.jobs.JobProgressListener;
49+
import org.apache.spark.util.Utils;
4850
import org.apache.zeppelin.interpreter.Interpreter;
4951
import org.apache.zeppelin.interpreter.InterpreterContext;
5052
import org.apache.zeppelin.interpreter.InterpreterException;
@@ -78,6 +80,7 @@
7880
import scala.tools.nsc.Settings;
7981
import scala.tools.nsc.interpreter.Completion.Candidates;
8082
import scala.tools.nsc.interpreter.Completion.ScalaCompleter;
83+
import scala.tools.nsc.interpreter.IMain;
8184
import scala.tools.nsc.interpreter.Results;
8285
import scala.tools.nsc.settings.MutableSettings;
8386
import scala.tools.nsc.settings.MutableSettings.BooleanSetting;
@@ -115,6 +118,8 @@ public class SparkInterpreter extends Interpreter {
115118

116119
private Map<String, Object> binder;
117120
private SparkVersion sparkVersion;
121+
private File outputDir; // class outputdir for scala 2.11
122+
private HttpServer classServer; // classserver for scala 2.11
118123

119124

120125
public SparkInterpreter(Properties property) {
@@ -282,6 +287,19 @@ public SparkContext createSparkContext() {
282287
}
283288
}
284289

290+
291+
if (isScala2_11()) {
292+
SparkConf conf = new SparkConf();
293+
classServer = new HttpServer(
294+
conf,
295+
outputDir,
296+
new SecurityManager(conf),
297+
0,
298+
"HTTP server");
299+
classServer.start();
300+
classServerUri = classServer.uri();
301+
}
302+
285303
SparkConf conf =
286304
new SparkConf()
287305
.setMaster(getProperty("master"))
@@ -413,26 +431,49 @@ public void open() {
413431
* getClass.getClassLoader >> } >> in.setContextClassLoader()
414432
*/
415433
Settings settings = new Settings();
416-
if (getProperty("args") != null) {
417-
String[] argsArray = getProperty("args").split(" ");
418-
LinkedList<String> argList = new LinkedList<String>();
419-
for (String arg : argsArray) {
420-
argList.add(arg);
421-
}
422434

435+
// process args
436+
String args = getProperty("args");
437+
if (args == null) {
438+
args = "";
439+
}
440+
441+
String[] argsArray = args.split(" ");
442+
LinkedList<String> argList = new LinkedList<String>();
443+
for (String arg : argsArray) {
444+
argList.add(arg);
445+
}
446+
447+
if (isScala2_10()) {
423448
scala.collection.immutable.List<String> list =
424449
JavaConversions.asScalaBuffer(argList).toList();
425450

426-
if (isScala2_10()) {
427-
Object sparkCommandLine = instantiateClass(
428-
"org.apache.spark.repl.SparkCommandLine",
429-
new Class[]{ list.getClass() },
430-
new Object[]{ list });
451+
Object sparkCommandLine = instantiateClass(
452+
"org.apache.spark.repl.SparkCommandLine",
453+
new Class[]{ list.getClass() },
454+
new Object[]{ list });
431455

432-
settings = (Settings) invokeMethod(sparkCommandLine, "settings");
433-
} else {
434-
settings.processArguments(list, true);
456+
settings = (Settings) invokeMethod(sparkCommandLine, "settings");
457+
} else {
458+
String sparkReplClassDir = getProperty("spark.repl.classdir");
459+
if (sparkReplClassDir == null) {
460+
sparkReplClassDir = System.getProperty("spark.repl.classdir");
461+
}
462+
if (sparkReplClassDir == null) {
463+
sparkReplClassDir = System.getProperty("java.io.tmpdir");
435464
}
465+
466+
outputDir = Utils.createTempDir(sparkReplClassDir, "classdir");
467+
468+
argList.add("-Yrepl-class-based");
469+
argList.add("-Yrepl-outdir");
470+
argList.add(outputDir.getAbsolutePath());
471+
472+
473+
scala.collection.immutable.List<String> list =
474+
JavaConversions.asScalaBuffer(argList).toList();
475+
476+
settings.processArguments(list, true);
436477
}
437478

438479
// set classpath for scala compiler
@@ -526,24 +567,22 @@ public void open() {
526567
if (isScala2_10()) {
527568
invokeMethod(intp, "setContextClassLoader");
528569
invokeMethod(intp, "initializeSynchronous");
529-
}
530570

531-
if (classOutputDir == null) {
532-
classOutputDir = settings.outputDirs().getSingleOutput().get();
533-
} else {
534-
// change SparkIMain class output dir
535-
settings.outputDirs().setSingleOutput(classOutputDir);
536-
ClassLoader cl = (ClassLoader) invokeMethod(intp, "classLoader");
537-
try {
538-
Field rootField = cl.getClass().getSuperclass().getDeclaredField("root");
539-
rootField.setAccessible(true);
540-
rootField.set(cl, classOutputDir);
541-
} catch (NoSuchFieldException | IllegalAccessException e) {
542-
logger.error(e.getMessage(), e);
571+
if (classOutputDir == null) {
572+
classOutputDir = settings.outputDirs().getSingleOutput().get();
573+
} else {
574+
// change SparkIMain class output dir
575+
settings.outputDirs().setSingleOutput(classOutputDir);
576+
ClassLoader cl = (ClassLoader) invokeMethod(intp, "classLoader");
577+
try {
578+
Field rootField = cl.getClass().getSuperclass().getDeclaredField("root");
579+
rootField.setAccessible(true);
580+
rootField.set(cl, classOutputDir);
581+
} catch (NoSuchFieldException | IllegalAccessException e) {
582+
logger.error(e.getMessage(), e);
583+
}
543584
}
544-
}
545585

546-
if (isScala2_10()) {
547586
completor = instantiateClass(
548587
"SparkJLineCompletion",
549588
new Class[]{findClass("org.apache.spark.repl.SparkIMain")},
@@ -568,8 +607,8 @@ public void open() {
568607
z = new ZeppelinContext(sc, sqlc, null, dep,
569608
Integer.parseInt(getProperty("zeppelin.spark.maxResult")));
570609

571-
interpret("@transient var _binder = new java.util.HashMap[String, Object]()");
572-
binder = (Map<String, Object>) getValue("_binder");
610+
interpret("@transient val _binder = new java.util.HashMap[String, Object]()");
611+
binder = (Map<String, Object>) getLastObject();
573612
binder.put("sc", sc);
574613
binder.put("sqlc", sqlc);
575614
binder.put("z", z);
@@ -769,9 +808,14 @@ private String getCompletionTargetString(String text, int cursor) {
769808
return resultCompletionText;
770809
}
771810

811+
/*
812+
* this method doesn't work in scala 2.11
813+
* Somehow intp.valueOfTerm returns scala.None always with -Yrepl-class-based option
814+
*/
772815
public Object getValue(String name) {
773816
Object ret = invokeMethod(intp, "valueOfTerm", new Class[]{String.class}, new Object[]{name});
774-
if (ret instanceof None) {
817+
818+
if (ret instanceof None || ret instanceof scala.None$) {
775819
return null;
776820
} else if (ret instanceof Some) {
777821
return ((Some) ret).get();
@@ -780,6 +824,13 @@ public Object getValue(String name) {
780824
}
781825
}
782826

827+
public Object getLastObject() {
828+
IMain.Request r = (IMain.Request) invokeMethod(intp, "lastRequest");
829+
Object obj = r.lineRep().call("$result",
830+
JavaConversions.asScalaBuffer(new LinkedList<Object>()));
831+
return obj;
832+
}
833+
783834
String getJobGroup(InterpreterContext context){
784835
return "zeppelin-" + context.getParagraphId();
785836
}
@@ -1049,6 +1100,10 @@ public void close() {
10491100
if (numReferenceOfSparkContext.decrementAndGet() == 0) {
10501101
sc.stop();
10511102
sc = null;
1103+
if (classServer != null) {
1104+
classServer.stop();
1105+
classServer = null;
1106+
}
10521107
}
10531108

10541109
invokeMethod(intp, "close");
@@ -1144,7 +1199,7 @@ private boolean isScala2_10() {
11441199
}
11451200

11461201
private boolean isScala2_11() {
1147-
return !isScala2_11();
1202+
return !isScala2_10();
11481203
}
11491204

11501205

spark/src/test/java/org/apache/zeppelin/spark/SparkInterpreterTest.java

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -19,13 +19,15 @@
1919

2020
import static org.junit.Assert.*;
2121

22+
import java.io.BufferedReader;
2223
import java.io.File;
2324
import java.util.HashMap;
2425
import java.util.LinkedList;
2526
import java.util.Properties;
2627

2728
import org.apache.spark.SparkConf;
2829
import org.apache.spark.SparkContext;
30+
import org.apache.spark.repl.SparkILoop;
2931
import org.apache.zeppelin.display.AngularObjectRegistry;
3032
import org.apache.zeppelin.resource.LocalResourcePool;
3133
import org.apache.zeppelin.user.AuthenticationInfo;
@@ -39,6 +41,7 @@
3941
import org.junit.runners.MethodSorters;
4042
import org.slf4j.Logger;
4143
import org.slf4j.LoggerFactory;
44+
import scala.tools.nsc.interpreter.IMain;
4245

4346
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
4447
public class SparkInterpreterTest {
@@ -137,6 +140,7 @@ public void testBasicIntp() {
137140
assertEquals(InterpreterResult.Code.INCOMPLETE, incomplete.code());
138141
assertTrue(incomplete.message().length() > 0); // expecting some error
139142
// message
143+
140144
/*
141145
* assertEquals(1, repl.getValue("a")); assertEquals(2, repl.getValue("b"));
142146
* repl.interpret("val ver = sc.version");

0 commit comments

Comments
 (0)