Skip to content

Commit 11ba490

Browse files
committed
Add draft implementation of %python.sql for DataFrames
1 parent bd714c2 commit 11ba490

File tree

4 files changed

+245
-1
lines changed

4 files changed

+245
-1
lines changed

python/pom.xml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535

3636
<properties>
3737
<py4j.version>0.9.2</py4j.version>
38-
<python.test.exclude>**/PythonInterpreterWithPythonInstalledTest.java</python.test.exclude>
38+
<python.test.exclude>
39+
**/PythonInterpreterWithPythonInstalledTest.java,
40+
**/PythonPandasSqlInterpreterTest.java
41+
</python.test.exclude>
3942
</properties>
4043

4144
<dependencies>
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.zeppelin.python;
19+
20+
import java.io.IOException;
21+
import java.util.Properties;
22+
23+
import org.apache.zeppelin.interpreter.Interpreter;
24+
import org.apache.zeppelin.interpreter.InterpreterContext;
25+
import org.apache.zeppelin.interpreter.InterpreterResult;
26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
28+
29+
/**
30+
* SQL over Pandas DataFrame interpreter for %python group
31+
*
32+
* Matches the experience of %spark.sql over Spark DataFrames
33+
*/
34+
public class PythonPandasSqlInterpreter extends Interpreter {
35+
private static final Logger LOG = LoggerFactory.getLogger(PythonPandasSqlInterpreter.class);
36+
37+
private String SQL_BOOTSTRAP_FILE_PY = "/bootstrap_sql.py";
38+
39+
public PythonPandasSqlInterpreter(Properties property) {
40+
super(property);
41+
}
42+
43+
@Override
44+
public void open() {
45+
LOG.info("Open Python SQL interpreter instance: {}", this.toString());
46+
47+
//TODO(bzz): check by importing and catching ImportError
48+
//if (pandasAndNumpyAndPandasqlAreInstalled) {
49+
try {
50+
LOG.info("Bootstrap {} interpreter with {}", this.toString(), SQL_BOOTSTRAP_FILE_PY);
51+
PythonInterpreter python = (PythonInterpreter) this.getInterpreterInTheSameSessionByClassName(
52+
PythonInterpreter.class.getName());
53+
python.bootStrapInterpreter(SQL_BOOTSTRAP_FILE_PY);
54+
} catch (IOException e) {
55+
LOG.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e);
56+
}
57+
//}
58+
}
59+
60+
@Override
61+
public void close() {
62+
LOG.info("Close Python SQL interpreter instance: {}", this.toString());
63+
}
64+
65+
@Override
66+
public InterpreterResult interpret(String st, InterpreterContext context) {
67+
LOG.info("Running SQL query: '{}' over Pandas DataFrame", st);
68+
Interpreter python = this.getInterpreterInTheSameSessionByClassName(
69+
PythonInterpreter.class.getName());
70+
return python.interpret("print pysqldf('" + st + "')", context);
71+
}
72+
73+
@Override
74+
public void cancel(InterpreterContext context) {
75+
76+
}
77+
78+
@Override
79+
public FormType getFormType() {
80+
return null;
81+
}
82+
83+
@Override
84+
public int getProgress(InterpreterContext context) {
85+
return 0;
86+
}
87+
88+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one or more
2+
# contributor license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright ownership.
4+
# The ASF licenses this file to You under the Apache License, Version 2.0
5+
# (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Bootstrap code for org.apache.zeppelin.python.PythonPandasSqlInterpreter (exercised by PythonPandasSqlInterpreterTest)
17+
# It requires the following dependencies to be installed:
18+
# - numpy
19+
# - pandas
20+
# - pandasql
21+
22+
23+
import numpy as np
24+
import pandas as pd
25+
from pandasql import sqldf
26+
27+
def pysqldf(q):
    """Run the SQL query ``q`` through pandasql's ``sqldf``.

    ``globals()`` is passed as the environment.
    NOTE(review): presumably so that DataFrames defined at top level can be
    referenced as tables by their variable names - confirm.
    """
    return sqldf(q, globals())
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.zeppelin.python;
19+
20+
import static org.junit.Assert.assertEquals;
21+
import static org.junit.Assert.assertNotNull;
22+
import static org.junit.Assert.assertTrue;
23+
24+
import java.util.Arrays;
25+
import java.util.HashMap;
26+
import java.util.LinkedList;
27+
import java.util.Properties;
28+
29+
import org.apache.zeppelin.display.AngularObjectRegistry;
30+
import org.apache.zeppelin.display.GUI;
31+
import org.apache.zeppelin.interpreter.InterpreterContext;
32+
import org.apache.zeppelin.interpreter.InterpreterContextRunner;
33+
import org.apache.zeppelin.interpreter.InterpreterGroup;
34+
import org.apache.zeppelin.interpreter.InterpreterOutput;
35+
import org.apache.zeppelin.interpreter.InterpreterOutputListener;
36+
import org.apache.zeppelin.interpreter.InterpreterResult;
37+
import org.apache.zeppelin.interpreter.InterpreterResult.Type;
38+
import org.apache.zeppelin.user.AuthenticationInfo;
39+
40+
import org.junit.Before;
41+
import org.junit.Test;
42+
43+
/**
44+
* In order for this test to work, test env must have installed:
45+
* <ol>
46+
* - <li>Python</li>
47+
* - <li>NumPy</li>
48+
* - <li>Pandas DataFrame</li>
49+
* </ol>
50+
*
51+
* To run manually on such environment, use:
52+
* <code>
53+
* mvn "-Dtest=org.apache.zeppelin.python.PythonPandasSqlInterpreterTest" test -pl python
54+
* </code>
55+
*/
56+
public class PythonPandasSqlInterpreterTest {
57+
58+
private InterpreterGroup intpGroup;
59+
private PythonPandasSqlInterpreter sql;
60+
private PythonInterpreter python;
61+
62+
private InterpreterContext context;
63+
64+
@Before
65+
public void setUp() throws Exception {
66+
Properties p = new Properties();
67+
p.setProperty("zeppelin.python", "python");
68+
p.setProperty("zeppelin.python.maxResult", "100");
69+
70+
intpGroup = new InterpreterGroup();
71+
72+
python = new PythonInterpreter(p);
73+
python.setInterpreterGroup(intpGroup);
74+
python.open();
75+
76+
sql = new PythonPandasSqlInterpreter(p);
77+
sql.setInterpreterGroup(intpGroup);
78+
79+
intpGroup.put("note", Arrays.asList(python, sql));
80+
81+
context = new InterpreterContext("note", "id", "title", "text", new AuthenticationInfo(),
82+
new HashMap<String, Object>(), new GUI(),
83+
new AngularObjectRegistry(intpGroup.getId(), null), null,
84+
new LinkedList<InterpreterContextRunner>(), new InterpreterOutput(
85+
new InterpreterOutputListener() {
86+
@Override public void onAppend(InterpreterOutput out, byte[] line) {}
87+
@Override public void onUpdate(InterpreterOutput out, byte[] output) {}
88+
}));
89+
90+
//important to be last step
91+
sql.open();
92+
//it depends on python interpreter presence in the same group
93+
}
94+
95+
//@Test
96+
public void sqlOverTestDataPrintsTable() {
97+
//given
98+
// `import pandas as pd` and `import numpy as np` done
99+
// DataFrame \w test data
100+
python.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), "+
101+
"'name' : pd.Categorical(['moon','jobs','gates','park'])})", context);
102+
103+
104+
//when
105+
InterpreterResult ret = sql.interpret("select name, age from test where age < 40", context);
106+
assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
107+
assertEquals(Type.TABLE, ret.type());
108+
assertEquals("name\tage\nmoon\t33\npark\t34\n", ret.message());
109+
110+
assertEquals(InterpreterResult.Code.SUCCESS, sql.interpret("select case when name==\"aa\" then name else name end from test", context).code());
111+
}
112+
113+
@Test
114+
public void badSqlSyntaxFails() {
115+
//when
116+
InterpreterResult ret = sql.interpret("select wrong syntax", context);
117+
118+
//then
119+
assertNotNull("Interpreter returned 'null'", ret);
120+
//System.out.println("\nInterpreter response: \n" + ret.message());
121+
assertEquals(InterpreterResult.Code.ERROR, ret.code());
122+
assertTrue(ret.message().length() > 0);
123+
}
124+
125+
126+
}

0 commit comments

Comments
 (0)