Skip to content

Commit bab4481

Browse files
add scalar function for cast so it can be calculated at compile time
1 parent 900f01f commit bab4481

File tree

6 files changed

+295
-27
lines changed

6 files changed

+295
-27
lines changed

config/checkstyle.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
org.apache.pinot.controller.recommender.rules.io.params.RecommenderConstants.RulesToExecute.*,
137137
org.apache.pinot.controller.recommender.rules.utils.PredicateParseResult.*,
138138
org.apache.pinot.client.utils.Constants.*,
139+
org.apache.pinot.common.utils.PinotDataType.*,
139140
org.apache.pinot.segment.local.startree.StarTreeBuilderUtils.*,
140141
org.apache.pinot.segment.local.startree.v2.store.StarTreeIndexMapUtils.*,
141142
org.apache.pinot.segment.local.utils.GeometryType.*,

pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,19 @@
1818
*/
1919
package org.apache.pinot.common.function.scalar;
2020

21+
import com.google.common.base.Preconditions;
2122
import java.math.BigDecimal;
2223
import java.util.Base64;
24+
import org.apache.pinot.common.utils.PinotDataType;
2325
import org.apache.pinot.spi.annotations.ScalarFunction;
2426
import org.apache.pinot.spi.utils.BigDecimalUtils;
2527
import org.apache.pinot.spi.utils.BytesUtils;
2628

29+
import static org.apache.pinot.common.utils.PinotDataType.DOUBLE;
30+
import static org.apache.pinot.common.utils.PinotDataType.INTEGER;
31+
import static org.apache.pinot.common.utils.PinotDataType.LONG;
32+
import static org.apache.pinot.common.utils.PinotDataType.STRING;
33+
2734

2835
/**
2936
* Contains functions to convert a value from one data type to another.
@@ -32,6 +39,33 @@ public class DataTypeConversionFunctions {
3239
private DataTypeConversionFunctions() {
3340
}
3441

42+
@ScalarFunction
43+
public static Object cast(Object value, String targetTypeLiteral) {
44+
try {
45+
Class<?> clazz = value.getClass();
46+
Preconditions.checkArgument(!clazz.isArray() | clazz == byte[].class, "%s must not be an array type", clazz);
47+
PinotDataType sourceType = PinotDataType.getSingleValueType(clazz);
48+
String transformed = targetTypeLiteral.toUpperCase();
49+
PinotDataType targetDataType;
50+
if ("INT".equals(transformed)) {
51+
targetDataType = INTEGER;
52+
} else if ("VARCHAR".equals(transformed)) {
53+
targetDataType = STRING;
54+
} else {
55+
targetDataType = PinotDataType.valueOf(transformed);
56+
}
57+
if (sourceType == STRING && (targetDataType == INTEGER || targetDataType == LONG)) {
58+
if (String.valueOf(value).contains(".")) {
59+
// convert integers via double to avoid parse errors
60+
return targetDataType.convert(DOUBLE.convert(value, sourceType), DOUBLE);
61+
}
62+
}
63+
return targetDataType.convert(value, sourceType);
64+
} catch (IllegalArgumentException e) {
65+
throw new IllegalArgumentException("Unknown data type: " + targetTypeLiteral);
66+
}
67+
}
68+
3569
/**
3670
* Converts big decimal string representation to bytes.
3771
* Only a scale of up to 2 bytes is supported by the function
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/**
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.pinot.common.function.scalar;
20+
21+
import org.testng.annotations.DataProvider;
22+
import org.testng.annotations.Test;
23+
24+
import static org.testng.Assert.assertEquals;
25+
26+
27+
public class DataTypeConversionFunctionsTest {
28+
29+
@DataProvider(name = "testCases")
30+
public static Object[][] testCases() {
31+
return new Object[][]{
32+
{"a", "string", "a"},
33+
{"10", "int", 10},
34+
{"10", "long", 10L},
35+
{"10", "float", 10F},
36+
{"10", "double", 10D},
37+
{"10.0", "int", 10},
38+
{"10.0", "long", 10L},
39+
{"10.0", "float", 10F},
40+
{"10.0", "double", 10D},
41+
{10, "string", "10"},
42+
{10L, "string", "10"},
43+
{10F, "string", "10.0"},
44+
{10D, "string", "10.0"},
45+
{"a", "string", "a"},
46+
{10, "int", 10},
47+
{10L, "long", 10L},
48+
{10F, "float", 10F},
49+
{10D, "double", 10D},
50+
{10L, "int", 10},
51+
{10, "long", 10L},
52+
{10D, "float", 10F},
53+
{10F, "double", 10D},
54+
{"abc1", "bytes", new byte[]{(byte) 0xab, (byte) 0xc1}},
55+
{new byte[]{(byte) 0xab, (byte) 0xc1}, "string", "abc1"}
56+
};
57+
}
58+
59+
@Test(dataProvider = "testCases")
60+
public void test(Object value, String type, Object expected) {
61+
assertEquals(DataTypeConversionFunctions.cast(value, type), expected);
62+
}
63+
}

pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,41 +1629,19 @@ public void testReservedKeywords() {
16291629
public void testCastTransformation() {
16301630
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery("select CAST(25.65 AS int) from myTable");
16311631
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
1632-
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
1633-
Assert.assertEquals(
1634-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(), 25.65);
1635-
Assert.assertEquals(
1636-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(),
1637-
"INTEGER");
1632+
Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getLongValue(), 25);
16381633

16391634
pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST('20170825' AS LONG) from myTable");
16401635
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
1641-
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
1642-
Assert.assertEquals(
1643-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getStringValue(),
1644-
"20170825");
1645-
Assert.assertEquals(
1646-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(), "LONG");
1636+
Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getLongValue(), 20170825);
16471637

16481638
pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(20170825.0 AS Float) from myTable");
16491639
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
1650-
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
1651-
Assert.assertEquals(
1652-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(),
1653-
20170825.0);
1654-
Assert.assertEquals(
1655-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(),
1656-
"FLOAT");
1640+
Assert.assertEquals((float) pinotQuery.getSelectList().get(0).getLiteral().getDoubleValue(), 20170825.0F);
16571641

16581642
pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(20170825.0 AS dOuble) from myTable");
16591643
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
1660-
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
1661-
Assert.assertEquals(
1662-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(),
1663-
20170825.0);
1664-
Assert.assertEquals(
1665-
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(),
1666-
"DOUBLE");
1644+
Assert.assertEquals((float) pinotQuery.getSelectList().get(0).getLiteral().getDoubleValue(), 20170825.0F);
16671645

16681646
pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(column1 AS STRING) from myTable");
16691647
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);

pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import org.testng.Assert;
2424
import org.testng.annotations.Test;
2525

26+
import static org.apache.pinot.common.function.scalar.DataTypeConversionFunctions.cast;
27+
import static org.testng.Assert.assertEquals;
28+
2629

2730
public class CastTransformFunctionTest extends BaseTransformFunctionTest {
2831

@@ -32,52 +35,70 @@ public void testCastTransformFunction() {
3235
RequestContextUtils.getExpressionFromSQL(String.format("CAST(%s AS string)", INT_SV_COLUMN));
3336
TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
3437
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
35-
Assert.assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
38+
assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
3639
String[] expectedValues = new String[NUM_ROWS];
40+
String[] scalarStringValues = new String[NUM_ROWS];
3741
for (int i = 0; i < NUM_ROWS; i++) {
3842
expectedValues[i] = Integer.toString(_intSVValues[i]);
43+
scalarStringValues[i] = (String) cast(_intSVValues[i], "string");
3944
}
4045
testTransformFunction(transformFunction, expectedValues);
46+
assertEquals(expectedValues, scalarStringValues);
4147

4248
expression =
4349
RequestContextUtils.getExpressionFromSQL(String.format("CAST(CAST(%s as INT) as FLOAT)", FLOAT_SV_COLUMN));
4450
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
4551
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
4652
float[] expectedFloatValues = new float[NUM_ROWS];
53+
float[] scalarFloatValues = new float[NUM_ROWS];
4754
for (int i = 0; i < NUM_ROWS; i++) {
4855
expectedFloatValues[i] = (int) _floatSVValues[i];
56+
scalarFloatValues[i] = (float) cast(cast(_floatSVValues[i], "int"), "float");
4957
}
5058
testTransformFunction(transformFunction, expectedFloatValues);
59+
assertEquals(expectedFloatValues, scalarFloatValues);
5160

5261
expression = RequestContextUtils.getExpressionFromSQL(
5362
String.format("CAST(ADD(CAST(%s AS LONG), %s) AS STRING)", DOUBLE_SV_COLUMN, LONG_SV_COLUMN));
5463
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
5564
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
5665
for (int i = 0; i < NUM_ROWS; i++) {
5766
expectedValues[i] = Double.toString((double) (long) _doubleSVValues[i] + (double) _longSVValues[i]);
67+
scalarStringValues[i] = (String) cast(
68+
(double) (long) cast(_doubleSVValues[i], "long") + (double) _longSVValues[i], "string");
5869
}
5970
testTransformFunction(transformFunction, expectedValues);
71+
assertEquals(expectedValues, scalarStringValues);
6072

6173
expression = RequestContextUtils.getExpressionFromSQL(
6274
String.format("caSt(cAst(casT(%s as inT) + %s aS sTring) As DouBle)", FLOAT_SV_COLUMN, INT_SV_COLUMN));
6375
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
6476
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
6577
double[] expectedDoubleValues = new double[NUM_ROWS];
78+
double[] scalarDoubleValues = new double[NUM_ROWS];
6679
for (int i = 0; i < NUM_ROWS; i++) {
6780
expectedDoubleValues[i] = (double) (int) _floatSVValues[i] + (double) _intSVValues[i];
81+
scalarDoubleValues[i] =
82+
(double) cast(cast((double) (int) cast(_floatSVValues[i], "int") + (double) _intSVValues[i], "string"),
83+
"double");
6884
}
6985
testTransformFunction(transformFunction, expectedDoubleValues);
86+
assertEquals(expectedDoubleValues, scalarDoubleValues);
7087

7188
expression = RequestContextUtils.getExpressionFromSQL(String
7289
.format("CAST(CAST(%s AS INT) - CAST(%s AS FLOAT) / CAST(%s AS DOUBLE) AS LONG)", DOUBLE_SV_COLUMN,
7390
LONG_SV_COLUMN, INT_SV_COLUMN));
7491
transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
7592
Assert.assertTrue(transformFunction instanceof CastTransformFunction);
7693
long[] expectedLongValues = new long[NUM_ROWS];
94+
long[] longScalarValues = new long[NUM_ROWS];
7795
for (int i = 0; i < NUM_ROWS; i++) {
7896
expectedLongValues[i] =
7997
(long) ((double) (int) _doubleSVValues[i] - (double) (float) _longSVValues[i] / (double) _intSVValues[i]);
98+
longScalarValues[i] = (long) cast((double) (int) cast(_doubleSVValues[i], "int")
99+
- (double) (float) cast(_longSVValues[i], "float") / (double) cast(_intSVValues[i], "double"), "long");
80100
}
81101
testTransformFunction(transformFunction, expectedLongValues);
102+
assertEquals(expectedLongValues, longScalarValues);
82103
}
83104
}

0 commit comments

Comments
 (0)