Skip to content

Commit 8de0a62

Browse files
committed
Adding tuple sketch scalar functions
Adding more sketch integration test
1 parent b25b62a commit 8de0a62

File tree

3 files changed

+357
-0
lines changed

3 files changed

+357
-0
lines changed

pinot-core/src/main/java/org/apache/pinot/core/function/scalar/SketchFunctions.java

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.datasketches.theta.UpdateSketch;
3333
import org.apache.datasketches.tuple.aninteger.IntegerSketch;
3434
import org.apache.datasketches.tuple.aninteger.IntegerSummary;
35+
import org.apache.datasketches.tuple.aninteger.IntegerSummarySetOperations;
3536
import org.apache.pinot.core.common.ObjectSerDeUtils;
3637
import org.apache.pinot.spi.annotations.ScalarFunction;
3738
import org.apache.pinot.spi.utils.CommonConstants;
@@ -277,4 +278,95 @@ private static Sketch asThetaSketch(Object sketchObj) {
277278
+ sketchObj.getClass());
278279
}
279280
}
281+
282+
@ScalarFunction(names = {"intSumTupleSketchUnion", "int_sum_tuple_sketch_union"})
283+
public static byte[] intSumTupleSketchUnion(Object o1, Object o2) {
284+
return intSumTupleSketchUnion((int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK), o1, o2);
285+
}
286+
287+
@ScalarFunction(names = {"intSumTupleSketchUnion", "int_sum_tuple_sketch_union"})
288+
public static byte[] intSumTupleSketchUnion(int nomEntries, Object o1, Object o2) {
289+
return intTupleSketchUnionVar(IntegerSummary.Mode.Sum, nomEntries, o1, o2);
290+
}
291+
292+
@ScalarFunction(names = {"intMinTupleSketchUnion", "int_min_tuple_sketch_union"})
293+
public static byte[] intMinTupleSketchUnion(Object o1, Object o2) {
294+
return intMinTupleSketchUnion((int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK), o1, o2);
295+
}
296+
297+
@ScalarFunction(names = {"intMinTupleSketchUnion", "int_min_tuple_sketch_union"})
298+
public static byte[] intMinTupleSketchUnion(int nomEntries, Object o1, Object o2) {
299+
return intTupleSketchUnionVar(IntegerSummary.Mode.Min, nomEntries, o1, o2);
300+
}
301+
302+
@ScalarFunction(names = {"intMaxTupleSketchUnion", "int_max_tuple_sketch_union"})
303+
public static byte[] intMaxTupleSketchUnion(Object o1, Object o2) {
304+
return intMaxTupleSketchUnion((int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK), o1, o2);
305+
}
306+
307+
@ScalarFunction(names = {"intMaxTupleSketchUnion", "int_max_tuple_sketch_union"})
308+
public static byte[] intMaxTupleSketchUnion(int nomEntries, Object o1, Object o2) {
309+
return intTupleSketchUnionVar(IntegerSummary.Mode.Max, nomEntries, o1, o2);
310+
}
311+
312+
private static byte[] intTupleSketchUnionVar(IntegerSummary.Mode mode, int nomEntries, Object... sketchObjects) {
313+
org.apache.datasketches.tuple.Union<IntegerSummary>
314+
union = new org.apache.datasketches.tuple.Union<>(nomEntries,
315+
new IntegerSummarySetOperations(mode, mode));
316+
for (Object sketchObj : sketchObjects) {
317+
union.union(asIntegerSketch(sketchObj));
318+
}
319+
return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(union.getResult().compact());
320+
}
321+
322+
@ScalarFunction(names = {"intSumTupleSketchIntersect", "int_sum_tuple_sketch_intersect"})
323+
public static byte[] intSumTupleSketchIntersect(Object o1, Object o2) {
324+
return intTupleSketchIntersectVar(IntegerSummary.Mode.Sum, o1, o2);
325+
}
326+
327+
@ScalarFunction(names = {"intMinTupleSketchIntersect", "int_min_tuple_sketch_intersect"})
328+
public static byte[] intMinTupleSketchIntersect(Object o1, Object o2) {
329+
return intTupleSketchIntersectVar(IntegerSummary.Mode.Min, o1, o2);
330+
}
331+
332+
@ScalarFunction(names = {"intMaxTupleSketchIntersect", "int_max_tuple_sketch_intersect"})
333+
public static byte[] intMaxTupleSketchIntersect(Object o1, Object o2) {
334+
return intTupleSketchIntersectVar(IntegerSummary.Mode.Max, o1, o2);
335+
}
336+
337+
private static byte[] intTupleSketchIntersectVar(IntegerSummary.Mode mode, Object... sketchObjects) {
338+
org.apache.datasketches.tuple.Intersection<IntegerSummary> intersection =
339+
new org.apache.datasketches.tuple.Intersection<>(new IntegerSummarySetOperations(mode, mode));
340+
for (Object sketchObj : sketchObjects) {
341+
intersection.intersect(asIntegerSketch(sketchObj));
342+
}
343+
return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(intersection.getResult().compact());
344+
}
345+
346+
@ScalarFunction(names = {"intTupleSketchDiff", "int_tuple_sketch_diff"})
347+
public static byte[] intSumTupleSketchDiff(Object o1, Object o2) {
348+
org.apache.datasketches.tuple.AnotB<IntegerSummary> diff = new org.apache.datasketches.tuple.AnotB<>();
349+
diff.setA(asIntegerSketch(o1));
350+
diff.notB(asIntegerSketch(o2));
351+
return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(diff.getResult(false).compact());
352+
}
353+
354+
private static org.apache.datasketches.tuple.Sketch<IntegerSummary> asIntegerSketch(Object sketchObj) {
355+
if (sketchObj instanceof String) {
356+
byte[] decoded = Base64.getDecoder().decode((String) sketchObj);
357+
return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.deserialize(decoded);
358+
} else if (sketchObj instanceof org.apache.datasketches.tuple.Sketch) {
359+
return (org.apache.datasketches.tuple.Sketch<IntegerSummary>) sketchObj;
360+
} else if (sketchObj instanceof byte[]) {
361+
return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.deserialize((byte[]) sketchObj);
362+
} else {
363+
throw new RuntimeException("Exception occurred getting reading Tuple Sketch, unsupported Object type: "
364+
+ sketchObj.getClass());
365+
}
366+
}
367+
368+
@ScalarFunction(names = {"getIntTupleSketchEstimate", "get_int_tuple_sketch_estimate"})
369+
public static long getIntTupleSketchEstimate(Object o1) {
370+
return Math.round(asIntegerSketch(o1).getEstimate());
371+
}
280372
}

pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ThetaSketchTest.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,73 @@ public void testThetaSketchQueryV2(boolean useMultiStageQueryEngine)
450450
ImmutableMap.of("Female", 50 + 60 + 70 + 110 + 120 + 130, "Male", 80 + 90 + 100 + 140 + 150 + 160);
451451
runAndAssert(query, expected);
452452
}
453+
454+
// union all by gender
455+
{
456+
String query = "select dimValue, distinctCountThetaSketch(thetaSketchCol) from "
457+
+ "( "
458+
+ "SELECT dimValue, thetaSketchCol FROM " + getTableName()
459+
+ " where dimName = 'gender' and dimValue = 'Female' "
460+
+ "UNION ALL "
461+
+ "SELECT dimValue, thetaSketchCol FROM " + getTableName()
462+
+ " where dimName = 'gender' and dimValue = 'Male' "
463+
+ ") "
464+
+ "GROUP BY dimValue";
465+
ImmutableMap<String, Integer> expected =
466+
ImmutableMap.of("Female", 50 + 60 + 70 + 110 + 120 + 130, "Male", 80 + 90 + 100 + 140 + 150 + 160);
467+
runAndAssert(query, expected);
468+
}
469+
470+
// JOIN all by gender
471+
{
472+
String query = "select a.dimValue, distinctCountThetaSketch(b.thetaSketchCol) "
473+
+ "FROM "
474+
+ "(SELECT dimName, dimValue, thetaSketchCol FROM " + getTableName()
475+
+ " where dimName = 'gender' and dimValue = 'Female') a "
476+
+ "JOIN "
477+
+ "(SELECT dimName, dimValue, thetaSketchCol FROM " + getTableName()
478+
+ " where dimName = 'gender' and dimValue = 'Male') b "
479+
+ "ON a.dimName = b.dimName "
480+
+ "GROUP BY a.dimValue";
481+
ImmutableMap<String, Integer> expected =
482+
ImmutableMap.of("Female", 80 + 90 + 100 + 140 + 150 + 160);
483+
runAndAssert(query, expected);
484+
}
485+
{
486+
String query = "select b.dimValue, distinctCountThetaSketch(a.thetaSketchCol) "
487+
+ "FROM "
488+
+ "(SELECT dimName, dimValue, thetaSketchCol FROM " + getTableName()
489+
+ " where dimName = 'gender' and dimValue = 'Female') a "
490+
+ "JOIN "
491+
+ "(SELECT dimName, dimValue, thetaSketchCol FROM " + getTableName()
492+
+ " where dimName = 'gender' and dimValue = 'Male') b "
493+
+ "ON a.dimName = b.dimName "
494+
+ "GROUP BY b.dimValue";
495+
ImmutableMap<String, Integer> expected =
496+
ImmutableMap.of("Male", 50 + 60 + 70 + 110 + 120 + 130);
497+
runAndAssert(query, expected);
498+
}
499+
{
500+
String query = "SELECT "
501+
+ "GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT("
502+
+ " DISTINCT_COUNT_RAW_THETA_SKETCH(a.thetaSketchCol, ''), "
503+
+ " DISTINCT_COUNT_RAW_THETA_SKETCH(b.thetaSketchCol, ''))), "
504+
+ "GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_UNION("
505+
+ " DISTINCT_COUNT_RAW_THETA_SKETCH(a.thetaSketchCol, ''), "
506+
+ " DISTINCT_COUNT_RAW_THETA_SKETCH(b.thetaSketchCol, ''))) "
507+
+ "FROM "
508+
+ "(SELECT dimName, dimValue, thetaSketchCol FROM " + getTableName()
509+
+ " where dimName = 'gender' and dimValue = 'Female') a "
510+
+ "JOIN "
511+
+ "(SELECT dimName, dimValue, thetaSketchCol FROM " + getTableName()
512+
+ " where dimName = 'gender' and dimValue = 'Male') b "
513+
+ "ON a.dimName = b.dimName";
514+
JsonNode jsonNode = postQuery(query);
515+
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).longValue(),
516+
0);
517+
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(1).longValue(),
518+
50 + 60 + 70 + 110 + 120 + 130 + 80 + 90 + 100 + 140 + 150 + 160);
519+
}
453520
}
454521

455522
private void runAndAssert(String query, int expected)

0 commit comments

Comments
 (0)