Skip to content

Commit c178bb2

Browse files
committed
Adding vector scalar functions
1 parent 834c970 commit c178bb2

File tree

8 files changed

+818
-3
lines changed

8 files changed

+818
-3
lines changed

pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,22 @@ public enum TransformFunctionType {
205205
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC),
206206
ordinal -> ordinal > 1 && ordinal < 4)),
207207

208+
// Vector functions
209+
// TODO: Once a VECTOR type is defined, update the vector function definitions below accordingly.
210+
COSINE_DISTANCE("cosineDistance", ReturnTypes.explicit(SqlTypeName.DOUBLE),
211+
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC),
212+
ordinal -> ordinal > 1 && ordinal < 4), "cosine_distance"),
213+
INNER_PRODUCT("innerProduct", ReturnTypes.explicit(SqlTypeName.DOUBLE),
214+
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "inner_product"),
215+
L1_DISTANCE("l1Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE),
216+
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l1_distance"),
217+
L2_DISTANCE("l2Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE),
218+
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l2_distance"),
219+
VECTOR_DIMS("vectorDims", ReturnTypes.explicit(SqlTypeName.INTEGER),
220+
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_dims"),
221+
VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
222+
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"),
223+
208224
// Trigonometry
209225
SIN("sin"),
210226
COS("cos"),
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/**
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.pinot.common.function.scalar;
20+
21+
import com.google.common.base.Preconditions;
22+
import org.apache.pinot.spi.annotations.ScalarFunction;
23+
24+
25+
/**
26+
* Inbuilt Vector Transformation Functions
27+
* The functions can be used as UDFs in Query when added in the FunctionRegistry.
28+
* @ScalarFunction annotation is used with each method for the registration
29+
*
30+
* Example usage:
31+
*/
32+
public class VectorFunctions {
33+
private VectorFunctions() {
34+
}
35+
36+
/**
37+
* Returns the cosine distance between two vectors
38+
* @param vector1 vector1
39+
* @param vector2 vector2
40+
* @return cosine distance
41+
*/
42+
@ScalarFunction(names = {"cosinedistance", "cosine_distance"})
43+
public static double cosineDistance(float[] vector1, float[] vector2) {
44+
return cosineDistance(vector1, vector2, Double.NaN);
45+
}
46+
47+
/**
48+
* Returns the cosine distance between two vectors, with a default value if the norm of either vector is 0.
49+
* @param vector1 vector1
50+
* @param vector2 vector2
51+
* @param defaultValue default value when either vector has a norm of 0
52+
* @return cosine distance
53+
*/
54+
@ScalarFunction(names = {"cosinedistance", "cosine_distance"})
55+
public static double cosineDistance(float[] vector1, float[] vector2, double defaultValue) {
56+
validateVectors(vector1, vector2);
57+
double dotProduct = 0.0;
58+
double norm1 = 0.0;
59+
double norm2 = 0.0;
60+
for (int i = 0; i < vector1.length; i++) {
61+
dotProduct += vector1[i] * vector2[i];
62+
norm1 += Math.pow(vector1[i], 2);
63+
norm2 += Math.pow(vector2[i], 2);
64+
}
65+
if (norm1 == 0 || norm2 == 0) {
66+
return defaultValue;
67+
}
68+
return 1 - (dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2)));
69+
}
70+
71+
/**
72+
* Returns the inner product between two vectors
73+
* @param vector1 vector1
74+
* @param vector2 vector2
75+
* @return inner product
76+
*/
77+
@ScalarFunction(names = {"innerproduct", "inner_product"})
78+
public static double innerProduct(float[] vector1, float[] vector2) {
79+
validateVectors(vector1, vector2);
80+
double dotProduct = 0.0;
81+
for (int i = 0; i < vector1.length; i++) {
82+
dotProduct += vector1[i] * vector2[i];
83+
}
84+
return dotProduct;
85+
}
86+
87+
/**
88+
* Returns the L2 distance between two vectors
89+
* @param vector1 vector1
90+
* @param vector2 vector2
91+
* @return L2 distance
92+
*/
93+
@ScalarFunction(names = {"l2distance", "l2_distance"})
94+
public static double l2Distance(float[] vector1, float[] vector2) {
95+
validateVectors(vector1, vector2);
96+
double distance = 0.0;
97+
for (int i = 0; i < vector1.length; i++) {
98+
distance += Math.pow(vector1[i] - vector2[i], 2);
99+
}
100+
return Math.sqrt(distance);
101+
}
102+
103+
/**
104+
* Returns the L1 distance between two vectors
105+
* @param vector1 vector1
106+
* @param vector2 vector2
107+
* @return L1 distance
108+
*/
109+
@ScalarFunction(names = {"l1distance", "l1_distance"})
110+
public static double l1Distance(float[] vector1, float[] vector2) {
111+
validateVectors(vector1, vector2);
112+
double distance = 0.0;
113+
for (int i = 0; i < vector1.length; i++) {
114+
distance += Math.abs(vector1[i] - vector2[i]);
115+
}
116+
return distance;
117+
}
118+
119+
/**
120+
* Returns the number of dimensions in a vector
121+
* @param vector input vector
122+
* @return number of dimensions
123+
*/
124+
@ScalarFunction(names = {"vectordims", "vector_dims"})
125+
public static int vectorDims(float[] vector) {
126+
validateVector(vector);
127+
return vector.length;
128+
}
129+
130+
/**
131+
* Returns the norm of a vector
132+
* @param vector input vector
133+
* @return norm
134+
*/
135+
@ScalarFunction(names = {"vectornorm", "vector_norm"})
136+
public static double vectorNorm(float[] vector) {
137+
validateVector(vector);
138+
double norm = 0.0;
139+
for (int i = 0; i < vector.length; i++) {
140+
norm += Math.pow(vector[i], 2);
141+
}
142+
return Math.sqrt(norm);
143+
}
144+
145+
public static void validateVectors(float[] vector1, float[] vector2) {
146+
Preconditions.checkArgument(vector1 != null && vector2 != null, "Null vector passed");
147+
Preconditions.checkArgument(vector1.length == vector2.length, "Vector lengths do not match");
148+
}
149+
150+
public static void validateVector(float[] vector) {
151+
Preconditions.checkArgument(vector != null, "Null vector passed");
152+
Preconditions.checkArgument(vector.length > 0, "Empty vector passed");
153+
}
154+
}

pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@
7070
import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.SinhTransformFunction;
7171
import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.TanTransformFunction;
7272
import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.TanhTransformFunction;
73+
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.CosineDistanceTransformFunction;
74+
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.InnerProductTransformFunction;
75+
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.L1DistanceTransformFunction;
76+
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.L2DistanceTransformFunction;
77+
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.VectorDimsTransformFunction;
78+
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.VectorNormTransformFunction;
7379
import org.apache.pinot.core.query.request.context.QueryContext;
7480
import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
7581
import org.apache.pinot.segment.spi.datasource.DataSource;
@@ -217,6 +223,14 @@ private static Map<String, Class<? extends TransformFunction>> createRegistry()
217223
typeToImplementation.put(TransformFunctionType.DEGREES, DegreesTransformFunction.class);
218224
typeToImplementation.put(TransformFunctionType.RADIANS, RadiansTransformFunction.class);
219225

226+
// Vector functions
227+
typeToImplementation.put(TransformFunctionType.COSINE_DISTANCE, CosineDistanceTransformFunction.class);
228+
typeToImplementation.put(TransformFunctionType.INNER_PRODUCT, InnerProductTransformFunction.class);
229+
typeToImplementation.put(TransformFunctionType.L1_DISTANCE, L1DistanceTransformFunction.class);
230+
typeToImplementation.put(TransformFunctionType.L2_DISTANCE, L2DistanceTransformFunction.class);
231+
typeToImplementation.put(TransformFunctionType.VECTOR_DIMS, VectorDimsTransformFunction.class);
232+
typeToImplementation.put(TransformFunctionType.VECTOR_NORM, VectorNormTransformFunction.class);
233+
220234
Map<String, Class<? extends TransformFunction>> registry = new HashMap<>(typeToImplementation.size());
221235
for (Map.Entry<TransformFunctionType, Class<? extends TransformFunction>> entry : typeToImplementation.entrySet()) {
222236
for (String alias : entry.getKey().getAlternativeNames()) {

0 commit comments

Comments
 (0)