Skip to content

Commit 913f0d8

Browse files
Add support for Base64 Encode/Decode Scalar Functions (#9114)
* before testing on agg * added null case tests for encode/decode, added invalid check for decode * roll back style changes * minor changes * roll back style changes * added base64 with user encoding scheme, added nested transform function for single arg * minor change * tested binary to base64 * code reformat * simplified logic, edited all tests, before reformat * checked style * fixed format issue * fixed format issue * added invalid arg test for toBase64
1 parent dad2586 commit 913f0d8

File tree

5 files changed

+423
-3
lines changed

5 files changed

+423
-3
lines changed

pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/LiteralOnlyBrokerRequestTest.java

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020

2121
import com.fasterxml.jackson.databind.JsonNode;
2222
import java.util.Collections;
23+
import java.util.List;
2324
import java.util.Random;
2425
import java.util.concurrent.TimeUnit;
2526
import org.apache.pinot.broker.broker.AccessControlFactory;
2627
import org.apache.pinot.broker.broker.AllowAllAccessControlFactory;
2728
import org.apache.pinot.common.metrics.BrokerMetrics;
2829
import org.apache.pinot.common.response.broker.BrokerResponseNative;
30+
import org.apache.pinot.common.response.broker.ResultTable;
2931
import org.apache.pinot.common.utils.DataSchema;
3032
import org.apache.pinot.spi.env.PinotConfiguration;
3133
import org.apache.pinot.spi.metrics.PinotMetricUtils;
@@ -104,6 +106,22 @@ public void testLiteralOnlyTransformBrokerRequestFromSQL() {
104106
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser
105107
.compileToPinotQuery("SELECT count(*) from foo "
106108
+ "where bar = decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253')")));
109+
Assert.assertTrue(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
110+
"SELECT toUtf8('hello!')," + " fromUtf8(toUtf8('hello!')) FROM myTable")));
111+
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
112+
"SELECT reverse(fromUtf8(foo))," + " toUtf8('hello!') FROM myTable")));
113+
Assert.assertTrue(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
114+
"SELECT toBase64(toUtf8('hello!'))," + " fromBase64('aGVsbG8h') FROM myTable")));
115+
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
116+
"SELECT reverse(toBase64(foo))," + " toBase64(fromBase64('aGVsbG8h')) FROM myTable")));
117+
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(
118+
CalciteSqlParser.compileToPinotQuery("SELECT fromBase64(toBase64(to_utf8(foo))) FROM myTable")));
119+
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(
120+
CalciteSqlParser.compileToPinotQuery("SELECT count(*) from foo " + "where bar = toBase64(toASCII('hello!'))")));
121+
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(
122+
CalciteSqlParser.compileToPinotQuery("SELECT count(*) from foo " + "where bar = fromBase64('aGVsbG8h')")));
123+
Assert.assertFalse(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
124+
"SELECT count(*) from foo " + "where bar = fromUtf8(fromBase64('aGVsbG8h'))")));
107125
}
108126

109127
@Test
@@ -115,6 +133,10 @@ public void testLiteralOnlyWithAsBrokerRequestFromSQL() {
115133
Assert.assertTrue(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
116134
"SELECT encodeUrl('key1=value 1&key2=value@!$2&key3=value%3') AS encoded, "
117135
+ "decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253') AS decoded")));
136+
Assert.assertTrue(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
137+
"SELECT toUtf8('hello!') AS encoded, " + "fromUtf8(toUtf8('hello!')) AS decoded")));
138+
Assert.assertTrue(BaseBrokerRequestHandler.isLiteralOnlyQuery(CalciteSqlParser.compileToPinotQuery(
139+
"SELECT toBase64(toUtf8('hello!')) AS encoded, " + "fromBase64('aGVsbG8h') AS decoded")));
118140
}
119141

120142
@Test
@@ -211,6 +233,76 @@ public void testBrokerRequestHandlerWithAsFunction()
211233
Assert.assertEquals(brokerResponse.getResultTable().getRows().get(0)[1].toString(),
212234
"key1=value 1&key2=value@!$2&key3=value%3");
213235
Assert.assertEquals(brokerResponse.getTotalDocs(), 0);
236+
237+
request = JsonUtils.stringToJsonNode(
238+
"{\"sql\":\"SELECT toBase64(toUtf8('hello!')) AS encoded, " + "fromUtf8(fromBase64('aGVsbG8h')) AS decoded\"}");
239+
requestStats = Tracing.getTracer().createRequestScope();
240+
brokerResponse = requestHandler.handleRequest(request, null, requestStats);
241+
ResultTable resultTable = brokerResponse.getResultTable();
242+
DataSchema dataSchema = resultTable.getDataSchema();
243+
List<Object[]> rows = resultTable.getRows();
244+
Assert.assertEquals(dataSchema.getColumnName(0), "encoded");
245+
Assert.assertEquals(dataSchema.getColumnDataType(0), DataSchema.ColumnDataType.STRING);
246+
Assert.assertEquals(dataSchema.getColumnName(1), "decoded");
247+
Assert.assertEquals(dataSchema.getColumnDataType(1), DataSchema.ColumnDataType.STRING);
248+
Assert.assertEquals(rows.size(), 1);
249+
Assert.assertEquals(rows.get(0).length, 2);
250+
Assert.assertEquals(rows.get(0)[0].toString(), "aGVsbG8h");
251+
Assert.assertEquals(rows.get(0)[1].toString(), "hello!");
252+
Assert.assertEquals(brokerResponse.getTotalDocs(), 0);
253+
254+
request = JsonUtils.stringToJsonNode(
255+
"{\"sql\":\"SELECT fromUtf8(fromBase64(toBase64(toUtf8('nested')))) AS output\"}");
256+
requestStats = Tracing.getTracer().createRequestScope();
257+
brokerResponse = requestHandler.handleRequest(request, null, requestStats);
258+
resultTable = brokerResponse.getResultTable();
259+
dataSchema = resultTable.getDataSchema();
260+
rows = resultTable.getRows();
261+
Assert.assertEquals(dataSchema.getColumnName(0), "output");
262+
Assert.assertEquals(dataSchema.getColumnDataType(0), DataSchema.ColumnDataType.STRING);
263+
Assert.assertEquals(rows.size(), 1);
264+
Assert.assertEquals(rows.get(0).length, 1);
265+
Assert.assertEquals(rows.get(0)[0].toString(), "nested");
266+
Assert.assertEquals(brokerResponse.getTotalDocs(), 0);
267+
268+
request = JsonUtils.stringToJsonNode(
269+
"{\"sql\":\"SELECT toBase64(toUtf8('this is a long string that will encode to more than 76 characters using "
270+
+ "base64'))"
271+
+ " AS encoded\"}");
272+
requestStats = Tracing.getTracer().createRequestScope();
273+
brokerResponse = requestHandler.handleRequest(request, null, requestStats);
274+
resultTable = brokerResponse.getResultTable();
275+
dataSchema = resultTable.getDataSchema();
276+
rows = resultTable.getRows();
277+
Assert.assertEquals(dataSchema.getColumnName(0), "encoded");
278+
Assert.assertEquals(dataSchema.getColumnDataType(0), DataSchema.ColumnDataType.STRING);
279+
Assert.assertEquals(rows.size(), 1);
280+
Assert.assertEquals(rows.get(0).length, 1);
281+
Assert.assertEquals(rows.get(0)[0].toString(),
282+
"dGhpcyBpcyBhIGxvbmcgc3RyaW5nIHRoYXQgd2lsbCBlbmNvZGUgdG8gbW9yZSB0aGFuIDc2IGNoYXJhY3RlcnMgdXNpbmcgYmFzZTY0");
283+
Assert.assertEquals(brokerResponse.getTotalDocs(), 0);
284+
285+
request = JsonUtils.stringToJsonNode("{\"sql\":\"SELECT fromUtf8(fromBase64"
286+
+ "('dGhpcyBpcyBhIGxvbmcgc3RyaW5nIHRoYXQgd2lsbCBlbmNvZGUgdG8gbW9yZSB0aGFuIDc2IGNoYXJhY3RlcnMgdXNpbmcgYmFzZTY0"
287+
+ "')) AS decoded\"}");
288+
requestStats = Tracing.getTracer().createRequestScope();
289+
brokerResponse = requestHandler.handleRequest(request, null, requestStats);
290+
resultTable = brokerResponse.getResultTable();
291+
dataSchema = resultTable.getDataSchema();
292+
rows = resultTable.getRows();
293+
Assert.assertEquals(dataSchema.getColumnName(0), "decoded");
294+
Assert.assertEquals(dataSchema.getColumnDataType(0), DataSchema.ColumnDataType.STRING);
295+
Assert.assertEquals(rows.size(), 1);
296+
Assert.assertEquals(rows.get(0).length, 1);
297+
Assert.assertEquals(rows.get(0)[0].toString(),
298+
"this is a long string that will encode to more than 76 characters using base64");
299+
Assert.assertEquals(brokerResponse.getTotalDocs(), 0);
300+
301+
request = JsonUtils.stringToJsonNode("{\"sql\":\"SELECT fromBase64" + "(0) AS decoded\"}");
302+
requestStats = Tracing.getTracer().createRequestScope();
303+
brokerResponse = requestHandler.handleRequest(request, null, requestStats);
304+
Assert.assertTrue(
305+
brokerResponse.getProcessingExceptions().get(0).getMessage().contains("IllegalArgumentException"));
214306
}
215307

216308
/** Tests for EXPLAIN PLAN for literal only queries. */

pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.net.URLEncoder;
2424
import java.nio.charset.StandardCharsets;
2525
import java.text.Normalizer;
26+
import java.util.Base64;
2627
import java.util.regex.Matcher;
2728
import java.util.regex.Pattern;
2829
import org.apache.commons.lang3.StringUtils;
@@ -393,6 +394,15 @@ public static byte[] toUtf8(String input) {
393394
return input.getBytes(StandardCharsets.UTF_8);
394395
}
395396

397+
/**
398+
* @param input UTF-8 encoded bytes
399+
* @return string decoded from the UTF-8 bytes
400+
*/
401+
@ScalarFunction
402+
public static String fromUtf8(byte[] input) {
403+
return new String(input, StandardCharsets.UTF_8);
404+
}
405+
396406
/**
397407
* @see StandardCharsets#US_ASCII#encode(String)
398408
* @param input
@@ -560,4 +570,22 @@ public static String decodeUrl(String input)
560570
throws UnsupportedEncodingException {
561571
return URLDecoder.decode(input, StandardCharsets.UTF_8.toString());
562572
}
573+
574+
/**
575+
* @param input binary data
576+
* @return Base64 encoded String
577+
*/
578+
@ScalarFunction
579+
public static String toBase64(byte[] input) {
580+
return Base64.getEncoder().encodeToString(input);
581+
}
582+
583+
/**
584+
* @param input Base64 encoded String
585+
* @return decoded binary data
586+
*/
587+
@ScalarFunction
588+
public static byte[] fromBase64(String input) {
589+
return Base64.getDecoder().decode(input);
590+
}
563591
}

pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,6 +1890,68 @@ public void testCompilationInvokedFunction() {
18901890
decoded = and.getOperands().get(1).getFunctionCall().getOperands().get(1).getLiteral().getStringValue();
18911891
Assert.assertEquals(encoded, "key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253");
18921892
Assert.assertEquals(decoded, "key1=value 1&key2=value@!$2&key3=value%3");
1893+
1894+
query = "select toBase64(toUtf8('hello!')), fromUtf8(fromBase64('aGVsbG8h')) from mytable";
1895+
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
1896+
String encodedBase64 = pinotQuery.getSelectList().get(0).getLiteral().getStringValue();
1897+
String decodedBase64 = pinotQuery.getSelectList().get(1).getLiteral().getStringValue();
1898+
Assert.assertEquals(encodedBase64, "aGVsbG8h");
1899+
Assert.assertEquals(decodedBase64, "hello!");
1900+
1901+
query = "select toBase64(fromBase64('aGVsbG8h')), fromUtf8(fromBase64(toBase64(toUtf8('hello!')))) from mytable";
1902+
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
1903+
encodedBase64 = pinotQuery.getSelectList().get(0).getLiteral().getStringValue();
1904+
decodedBase64 = pinotQuery.getSelectList().get(1).getLiteral().getStringValue();
1905+
Assert.assertEquals(encodedBase64, "aGVsbG8h");
1906+
Assert.assertEquals(decodedBase64, "hello!");
1907+
1908+
query =
1909+
"select toBase64(toUtf8(upper('hello!'))), fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!'))))) from "
1910+
+ "mytable";
1911+
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
1912+
encodedBase64 = pinotQuery.getSelectList().get(0).getLiteral().getStringValue();
1913+
decodedBase64 = pinotQuery.getSelectList().get(1).getLiteral().getStringValue();
1914+
Assert.assertEquals(encodedBase64, "SEVMTE8h");
1915+
Assert.assertEquals(decodedBase64, "HELLO!");
1916+
1917+
query =
1918+
"select reverse(fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!')))))) from mytable where fromUtf8"
1919+
+ "(fromBase64(toBase64(toUtf8(upper('hello!')))))"
1920+
+ " = bar";
1921+
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
1922+
String arg1 = pinotQuery.getSelectList().get(0).getLiteral().getStringValue();
1923+
String leftOp =
1924+
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getStringValue();
1925+
Assert.assertEquals(arg1, "!OLLEH");
1926+
Assert.assertEquals(leftOp, "HELLO!");
1927+
1928+
query = "select a from mytable where foo = toBase64(toUtf8('hello!')) and bar = fromUtf8(fromBase64('aGVsbG8h'))";
1929+
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
1930+
and = pinotQuery.getFilterExpression().getFunctionCall();
1931+
encoded = and.getOperands().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue();
1932+
decoded = and.getOperands().get(1).getFunctionCall().getOperands().get(1).getLiteral().getStringValue();
1933+
Assert.assertEquals(encoded, "aGVsbG8h");
1934+
Assert.assertEquals(decoded, "hello!");
1935+
1936+
query = "select fromBase64('hello') from mytable";
1937+
Exception expectedError = null;
1938+
try {
1939+
CalciteSqlParser.compileToPinotQuery(query);
1940+
} catch (Exception e) {
1941+
expectedError = e;
1942+
}
1943+
Assert.assertNotNull(expectedError);
1944+
Assert.assertTrue(expectedError instanceof SqlCompilationException);
1945+
1946+
query = "select toBase64('hello!') from mytable";
1947+
expectedError = null;
1948+
try {
1949+
CalciteSqlParser.compileToPinotQuery(query);
1950+
} catch (Exception e) {
1951+
expectedError = e;
1952+
}
1953+
Assert.assertNotNull(expectedError);
1954+
Assert.assertTrue(expectedError instanceof SqlCompilationException);
18931955
}
18941956

18951957
@Test
@@ -2012,6 +2074,40 @@ public void testCompileTimeExpression() {
20122074
Assert.assertNotNull(expression.getFunctionCall());
20132075
Assert.assertEquals(expression.getFunctionCall().getOperator(), "count");
20142076
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "*");
2077+
2078+
expression = CalciteSqlParser.compileToExpression("toBase64(toUtf8('hello!'))");
2079+
Assert.assertNotNull(expression.getFunctionCall());
2080+
pinotQuery.setFilterExpression(expression);
2081+
pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
2082+
expression = pinotQuery.getFilterExpression();
2083+
Assert.assertNotNull(expression.getLiteral());
2084+
Assert.assertEquals(expression.getLiteral().getFieldValue(), "aGVsbG8h");
2085+
2086+
expression = CalciteSqlParser.compileToExpression("fromUtf8(fromBase64('aGVsbG8h'))");
2087+
Assert.assertNotNull(expression.getFunctionCall());
2088+
pinotQuery.setFilterExpression(expression);
2089+
pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
2090+
expression = pinotQuery.getFilterExpression();
2091+
Assert.assertNotNull(expression.getLiteral());
2092+
Assert.assertEquals(expression.getLiteral().getFieldValue(), "hello!");
2093+
2094+
expression = CalciteSqlParser.compileToExpression("fromBase64(foo)");
2095+
Assert.assertNotNull(expression.getFunctionCall());
2096+
pinotQuery.setFilterExpression(expression);
2097+
pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
2098+
expression = pinotQuery.getFilterExpression();
2099+
Assert.assertNotNull(expression.getFunctionCall());
2100+
Assert.assertEquals(expression.getFunctionCall().getOperator(), "frombase64");
2101+
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "foo");
2102+
2103+
expression = CalciteSqlParser.compileToExpression("toBase64(foo)");
2104+
Assert.assertNotNull(expression.getFunctionCall());
2105+
pinotQuery.setFilterExpression(expression);
2106+
pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
2107+
expression = pinotQuery.getFilterExpression();
2108+
Assert.assertNotNull(expression.getFunctionCall());
2109+
Assert.assertEquals(expression.getFunctionCall().getOperator(), "tobase64");
2110+
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "foo");
20152111
}
20162112

20172113
@Test

pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.nio.charset.StandardCharsets;
2424
import java.text.Normalizer;
2525
import java.util.Arrays;
26+
import java.util.Base64;
2627
import org.apache.commons.codec.digest.DigestUtils;
2728
import org.apache.commons.lang3.ArrayUtils;
2829
import org.apache.commons.lang3.StringUtils;
@@ -819,4 +820,27 @@ public void testConcatStringTransformFunction() {
819820
}
820821
testTransformFunctionMV(transformFunction, expectedValues);
821822
}
823+
824+
@Test
825+
public void testBase64TransformFunction() {
826+
ExpressionContext expression = RequestContextUtils.getExpression(String.format("toBase64(%s)", BYTES_SV_COLUMN));
827+
TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
828+
assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper);
829+
assertEquals(transformFunction.getName(), "toBase64");
830+
String[] expectedValues = new String[NUM_ROWS];
831+
for (int i = 0; i < NUM_ROWS; i++) {
832+
expectedValues[i] = Base64.getEncoder().encodeToString(_bytesSVValues[i]);
833+
}
834+
testTransformFunction(transformFunction, expectedValues);
835+
836+
expression = RequestContextUtils.getExpression(String.format("fromBase64(toBase64(%s))", BYTES_SV_COLUMN));
837+
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
838+
assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper);
839+
assertEquals(transformFunction.getName(), "fromBase64");
840+
byte[][] expectedBinaryValues = new byte[NUM_ROWS][];
841+
for (int i = 0; i < NUM_ROWS; i++) {
842+
expectedBinaryValues[i] = Base64.getDecoder().decode(Base64.getEncoder().encodeToString(_bytesSVValues[i]));
843+
}
844+
testTransformFunction(transformFunction, expectedBinaryValues);
845+
}
822846
}

0 commit comments

Comments
 (0)