Skip to content

Commit 2ad741d

Browse files
authored
[multistage] Add support for the ranking ROW_NUMBER() window function (#10527)
* Add support for the ranking ROW_NUMBER() window function * Compute ROW_NUMBER() window function as long instead of double, rebase and fix tests
1 parent 70c4c5b commit 2ad741d

File tree

6 files changed

+1322
-46
lines changed

6 files changed

+1322
-46
lines changed

pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.calcite.rel.rules;
2020

2121
import com.google.common.base.Preconditions;
22+
import com.google.common.collect.ImmutableList;
2223
import com.google.common.collect.ImmutableSet;
2324
import java.util.Collections;
2425
import java.util.HashSet;
@@ -49,7 +50,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
4950
// Supported window functions
5051
// OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR
5152
private static final Set<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND = ImmutableSet.of(SqlKind.SUM, SqlKind.SUM0,
52-
SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.OTHER_FUNCTION);
53+
SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.OTHER_FUNCTION);
5354

5455
public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) {
5556
super(operand(LogicalWindow.class, any()), factory, null);
@@ -145,19 +146,26 @@ private void validateWindowAggCallsSupported(Window.Group windowGroup) {
145146
}
146147

147148
private void validateWindowFrames(Window.Group windowGroup) {
149+
// Has ROWS only aggregation call kind (e.g. ROW_NUMBER)?
150+
boolean isRowsOnlyTypeAggregateCall = isRowsOnlyAggregationCallType(windowGroup.aggCalls);
148151
// For Phase 1 only the default frame is supported
149-
Preconditions.checkState(!windowGroup.isRows, "Default frame must be of type RANGE and not ROWS");
152+
Preconditions.checkState(!windowGroup.isRows || isRowsOnlyTypeAggregateCall,
153+
"Default frame must be of type RANGE and not ROWS unless this is a ROWS only aggregation function");
150154
Preconditions.checkState(windowGroup.lowerBound.isPreceding() && windowGroup.lowerBound.isUnbounded(),
151155
String.format("Lower bound must be UNBOUNDED PRECEDING but it is: %s", windowGroup.lowerBound));
152-
if (windowGroup.orderKeys.getKeys().isEmpty()) {
156+
if (windowGroup.orderKeys.getKeys().isEmpty() && !isRowsOnlyTypeAggregateCall) {
153157
Preconditions.checkState(windowGroup.upperBound.isFollowing() && windowGroup.upperBound.isUnbounded(),
154-
String.format("Upper bound must be UNBOUNDED PRECEDING but it is: %s", windowGroup.upperBound));
158+
String.format("Upper bound must be UNBOUNDED FOLLOWING but it is: %s", windowGroup.upperBound));
155159
} else {
156160
Preconditions.checkState(windowGroup.upperBound.isCurrentRow(),
157161
String.format("Upper bound must be CURRENT ROW but it is: %s", windowGroup.upperBound));
158162
}
159163
}
160164

165+
private boolean isRowsOnlyAggregationCallType(ImmutableList<Window.RexWinAggCall> aggCalls) {
166+
return aggCalls.stream().anyMatch(aggCall -> aggCall.getKind().equals(SqlKind.ROW_NUMBER));
167+
}
168+
161169
private boolean isPartitionByOnlyQuery(Window.Group windowGroup) {
162170
boolean isPartitionByOnly = false;
163171
if (windowGroup.orderKeys.getKeys().isEmpty()) {

pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ protected Object[][] provideQueries() {
109109
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (PARTITION BY a.col2, a.col1) FROM a"},
110110
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2, a.col1), MIN(a.col3) OVER (ORDER BY a.col2, "
111111
+ "a.col1) FROM a"},
112+
new Object[]{"SELECT a.col1, ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
112113
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
113114
new Object[]{"SELECT /*+ skipLeafStageGroupByAggregation */ a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0"
114115
+ " AND a.col2 = 'a' GROUP BY a.col1"},

0 commit comments

Comments
 (0)