Skip to content

Commit 5fa4737

Browse files
do not identify function types by throwing exceptions (#8137)
* do not identify function types by throwing exceptions * comments
1 parent 1382d29 commit 5fa4737

File tree

9 files changed

+82
-72
lines changed

9 files changed

+82
-72
lines changed

pinot-common/src/main/java/org/apache/pinot/common/function/FunctionDefinitionRegistry.java

Lines changed: 0 additions & 48 deletions
This file was deleted.

pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
*/
1919
package org.apache.pinot.common.function;
2020

21+
import java.util.Arrays;
22+
import java.util.Set;
23+
import java.util.stream.Collectors;
24+
import java.util.stream.Stream;
25+
import org.apache.commons.lang.StringUtils;
26+
27+
2128
public enum TransformFunctionType {
2229
// Aggregation functions for single-valued columns
2330
ADD("add"),
@@ -101,12 +108,28 @@ public enum TransformFunctionType {
101108
// Geo indexing
102109
GEOTOH3("geoToH3");
103110

111+
private static final Set<String> NAMES = Arrays.stream(values())
112+
.flatMap(func -> Stream.of(func.getName(), StringUtils.remove(func.getName(), '_').toUpperCase(),
113+
func.getName().toUpperCase(), func.getName().toLowerCase(), func.name(), func.name().toLowerCase()))
114+
.collect(Collectors.toSet());
115+
104116
private final String _name;
105117

106118
TransformFunctionType(String name) {
107119
_name = name;
108120
}
109121

122+
public static boolean isTransformFunction(String functionName) {
123+
if (NAMES.contains(functionName)) {
124+
return true;
125+
}
126+
// scalar functions
127+
if (FunctionRegistry.containsFunction(functionName)) {
128+
return true;
129+
}
130+
return NAMES.contains(StringUtils.remove(functionName, '_').toUpperCase());
131+
}
132+
110133
/**
111134
* Returns the corresponding transform function type for the given function name.
112135
*/
@@ -120,7 +143,7 @@ public static TransformFunctionType getTransformFunctionType(String functionName
120143
}
121144
// Support function name of both jsonExtractScalar and json_extract_scalar
122145
if (upperCaseFunctionName.contains("_")) {
123-
return getTransformFunctionType(upperCaseFunctionName.replace("_", ""));
146+
return getTransformFunctionType(StringUtils.remove(upperCaseFunctionName, '_'));
124147
}
125148
throw new IllegalArgumentException("Invalid transform function name: " + functionName);
126149
}

pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.ArrayList;
2222
import java.util.Collections;
2323
import java.util.List;
24-
import org.apache.pinot.common.function.FunctionDefinitionRegistry;
2524
import org.apache.pinot.common.request.Expression;
2625
import org.apache.pinot.common.request.ExpressionType;
2726
import org.apache.pinot.common.request.FilterOperator;
@@ -122,7 +121,7 @@ public static FunctionContext getFunction(Function thriftFunction) {
122121
Collections.singletonList(ExpressionContext.forIdentifier("*")));
123122
}
124123
FunctionContext.Type functionType =
125-
FunctionDefinitionRegistry.isAggFunc(functionName) ? FunctionContext.Type.AGGREGATION
124+
AggregationFunctionType.isAggregationFunction(functionName) ? FunctionContext.Type.AGGREGATION
126125
: FunctionContext.Type.TRANSFORM;
127126
List<Expression> operands = thriftFunction.getOperands();
128127
if (operands != null) {
@@ -147,7 +146,7 @@ public static FunctionContext getFunction(FunctionCallAstNode astNode) {
147146
Collections.singletonList(ExpressionContext.forIdentifier("*")));
148147
}
149148
FunctionContext.Type functionType =
150-
FunctionDefinitionRegistry.isAggFunc(functionName) ? FunctionContext.Type.AGGREGATION
149+
AggregationFunctionType.isAggregationFunction(functionName) ? FunctionContext.Type.AGGREGATION
151150
: FunctionContext.Type.TRANSFORM;
152151
List<? extends AstNode> children = astNode.getChildren();
153152
if (children != null) {

pinot-common/src/main/java/org/apache/pinot/common/utils/SqlResultComparator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@
3838
import org.apache.calcite.sql.parser.SqlParser;
3939
import org.apache.calcite.sql.parser.babel.SqlBabelParserImpl;
4040
import org.apache.calcite.sql.validate.SqlConformanceEnum;
41-
import org.apache.pinot.common.function.FunctionDefinitionRegistry;
4241
import org.apache.pinot.common.request.Expression;
4342
import org.apache.pinot.common.request.ExpressionType;
4443
import org.apache.pinot.common.request.Function;
4544
import org.apache.pinot.common.request.PinotQuery;
45+
import org.apache.pinot.segment.spi.AggregationFunctionType;
4646
import org.apache.pinot.spi.utils.JsonUtils;
4747
import org.apache.pinot.sql.parsers.CalciteSqlParser;
4848
import org.slf4j.Logger;
@@ -509,7 +509,7 @@ private static boolean isSelectionQuery(String query) {
509509
if (expression.getType() == ExpressionType.FUNCTION) {
510510
Function functionCall = expression.getFunctionCall();
511511
String functionName = functionCall.getOperator();
512-
if (FunctionDefinitionRegistry.isAggFunc(functionName)) {
512+
if (AggregationFunctionType.isAggregationFunction(functionName)) {
513513
return false;
514514
}
515515
}

pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import java.util.Set;
2525
import java.util.TreeSet;
2626
import org.apache.calcite.sql.SqlKind;
27-
import org.apache.pinot.common.function.FunctionDefinitionRegistry;
2827
import org.apache.pinot.common.request.AggregationInfo;
2928
import org.apache.pinot.common.request.BrokerRequest;
3029
import org.apache.pinot.common.request.DataSource;
@@ -146,7 +145,7 @@ private void convertSelectList(PinotQuery pinotQuery, BrokerRequest brokerReques
146145
case FUNCTION:
147146
Function functionCall = expression.getFunctionCall();
148147
String functionName = functionCall.getOperator();
149-
if (FunctionDefinitionRegistry.isAggFunc(functionName)) {
148+
if (AggregationFunctionType.isAggregationFunction(functionName)) {
150149
AggregationInfo aggInfo = buildAggregationInfo(functionCall);
151150
if (aggregationInfoList == null) {
152151
aggregationInfoList = new ArrayList<>();

pinot-common/src/main/java/org/apache/pinot/pql/parsers/pql2/ast/OutputColumnAstNode.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
*/
1919
package org.apache.pinot.pql.parsers.pql2.ast;
2020

21-
import org.apache.pinot.common.function.FunctionDefinitionRegistry;
2221
import org.apache.pinot.common.request.BrokerRequest;
2322
import org.apache.pinot.common.request.Expression;
2423
import org.apache.pinot.common.request.PinotQuery;
2524
import org.apache.pinot.common.request.Selection;
2625
import org.apache.pinot.common.request.transform.TransformExpressionTree;
2726
import org.apache.pinot.common.utils.request.RequestUtils;
2827
import org.apache.pinot.pql.parsers.Pql2CompilationException;
28+
import org.apache.pinot.segment.spi.AggregationFunctionType;
2929

3030

3131
/**
@@ -37,7 +37,7 @@ public void updateBrokerRequest(BrokerRequest brokerRequest) {
3737
for (AstNode astNode : getChildren()) {
3838
if (astNode instanceof FunctionCallAstNode) {
3939
String functionName = ((FunctionCallAstNode) astNode).getName();
40-
if (FunctionDefinitionRegistry.isAggFunc(functionName)) {
40+
if (AggregationFunctionType.isAggregationFunction(functionName)) {
4141
FunctionCallAstNode node = (FunctionCallAstNode) astNode;
4242
brokerRequest.addToAggregationsInfo(node.buildAggregationInfo());
4343
} else {

pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
import org.apache.calcite.sql.validate.SqlConformanceEnum;
5050
import org.apache.commons.collections.CollectionUtils;
5151
import org.apache.commons.lang3.StringUtils;
52-
import org.apache.pinot.common.function.FunctionDefinitionRegistry;
5352
import org.apache.pinot.common.request.DataSource;
5453
import org.apache.pinot.common.request.Expression;
5554
import org.apache.pinot.common.request.ExpressionType;
@@ -247,10 +246,8 @@ public static boolean isAggregateExpression(Expression expression) {
247246
Function functionCall = expression.getFunctionCall();
248247
if (functionCall != null) {
249248
String operator = functionCall.getOperator();
250-
try {
251-
AggregationFunctionType.getAggregationFunctionType(operator);
249+
if (AggregationFunctionType.isAggregationFunction(operator)) {
252250
return true;
253-
} catch (IllegalArgumentException e) {
254251
}
255252
if (functionCall.getOperandsSize() > 0) {
256253
for (Expression operand : functionCall.getOperands()) {
@@ -593,7 +590,7 @@ private static Expression convertDistinctAndSelectListToFunctionExpression(SqlNo
593590
} else if (columnExpression.getType() == ExpressionType.FUNCTION) {
594591
Function functionCall = columnExpression.getFunctionCall();
595592
String function = functionCall.getOperator();
596-
if (FunctionDefinitionRegistry.isAggFunc(function)) {
593+
if (AggregationFunctionType.isAggregationFunction(function)) {
597594
throw new SqlCompilationException(
598595
"Syntax error: Use of DISTINCT with aggregation functions is not supported");
599596
}

pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package org.apache.pinot.common.function;
2020

21+
import org.apache.pinot.segment.spi.AggregationFunctionType;
2122
import org.testng.annotations.Test;
2223

2324
import static org.testng.Assert.assertFalse;
@@ -28,13 +29,30 @@ public class FunctionDefinitionRegistryTest {
2829

2930
@Test
3031
public void testIsAggFunc() {
31-
assertTrue(FunctionDefinitionRegistry.isAggFunc("count"));
32-
assertFalse(FunctionDefinitionRegistry.isAggFunc("toEpochSeconds"));
32+
assertTrue(AggregationFunctionType.isAggregationFunction("count"));
33+
assertTrue(AggregationFunctionType.isAggregationFunction("percentileRawEstMV"));
34+
assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILERAWESTMV"));
35+
assertTrue(AggregationFunctionType.isAggregationFunction("percentilerawestmv"));
36+
assertTrue(AggregationFunctionType.isAggregationFunction("percentile_raw_est_mv"));
37+
assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILE_RAW_EST_MV"));
38+
assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILEEST90"));
39+
assertTrue(AggregationFunctionType.isAggregationFunction("percentileest90"));
40+
assertFalse(AggregationFunctionType.isAggregationFunction("toEpochSeconds"));
3341
}
3442

3543
@Test
3644
public void testIsTransformFunc() {
37-
assertTrue(FunctionDefinitionRegistry.isTransformFunc("toEpochSeconds"));
38-
assertFalse(FunctionDefinitionRegistry.isTransformFunc("foo_bar"));
45+
assertTrue(TransformFunctionType.isTransformFunction("toEpochSeconds"));
46+
assertTrue(TransformFunctionType.isTransformFunction("json_extract_scalar"));
47+
assertTrue(TransformFunctionType.isTransformFunction("jsonextractscalar"));
48+
assertTrue(TransformFunctionType.isTransformFunction("JSON_EXTRACT_SCALAR"));
49+
assertTrue(TransformFunctionType.isTransformFunction("JSONEXTRACTSCALAR"));
50+
assertTrue(TransformFunctionType.isTransformFunction("jsonExtractScalar"));
51+
assertTrue(TransformFunctionType.isTransformFunction("ST_AsText"));
52+
assertTrue(TransformFunctionType.isTransformFunction("STAsText"));
53+
assertTrue(TransformFunctionType.isTransformFunction("stastext"));
54+
assertTrue(TransformFunctionType.isTransformFunction("ST_ASTEXT"));
55+
assertTrue(TransformFunctionType.isTransformFunction("STASTEXT"));
56+
assertFalse(TransformFunctionType.isTransformFunction("foo_bar"));
3957
}
4058
}

pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
*/
1919
package org.apache.pinot.segment.spi;
2020

21-
import org.apache.commons.lang3.StringUtils;
21+
import java.util.Arrays;
22+
import java.util.Set;
23+
import java.util.stream.Collectors;
24+
import java.util.stream.Stream;
25+
import org.apache.commons.lang.StringUtils;
2226

2327

2428
/**
@@ -71,6 +75,9 @@ public enum AggregationFunctionType {
7175
PERCENTILERAWTDIGESTMV("percentileRawTDigestMV"),
7276
DISTINCT("distinct");
7377

78+
private static final Set<String> NAMES = Arrays.stream(values()).flatMap(func -> Stream.of(func.name(),
79+
func.getName(), func.getName().toLowerCase())).collect(Collectors.toSet());
80+
7481
private final String _name;
7582

7683
AggregationFunctionType(String name) {
@@ -81,14 +88,29 @@ public String getName() {
8188
return _name;
8289
}
8390

91+
public static boolean isAggregationFunction(String functionName) {
92+
if (NAMES.contains(functionName)) {
93+
return true;
94+
}
95+
if (functionName.regionMatches(true, 0, "percentile", 0, 10)) {
96+
try {
97+
getAggregationFunctionType(functionName);
98+
return true;
99+
} catch (Exception ignore) {
100+
return false;
101+
}
102+
}
103+
String upperCaseFunctionName = StringUtils.remove(functionName, '_').toUpperCase();
104+
return NAMES.contains(upperCaseFunctionName);
105+
}
106+
84107
/**
85108
* Returns the corresponding aggregation function type for the given function name.
86109
* <p>NOTE: Underscores in the function name are ignored.
87110
*/
88111
public static AggregationFunctionType getAggregationFunctionType(String functionName) {
89-
String upperCaseFunctionName = StringUtils.remove(functionName, '_').toUpperCase();
90-
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
91-
String remainingFunctionName = upperCaseFunctionName.substring(10);
112+
if (functionName.regionMatches(true, 0, "percentile", 0, 10)) {
113+
String remainingFunctionName = StringUtils.remove(functionName, '_').substring(10).toUpperCase();
92114
if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) {
93115
return PERCENTILE;
94116
} else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) {
@@ -114,7 +136,7 @@ public static AggregationFunctionType getAggregationFunctionType(String function
114136
}
115137
} else {
116138
try {
117-
return AggregationFunctionType.valueOf(upperCaseFunctionName);
139+
return AggregationFunctionType.valueOf(StringUtils.remove(functionName, '_').toUpperCase());
118140
} catch (IllegalArgumentException e) {
119141
throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
120142
}

0 commit comments

Comments
 (0)