Skip to content

Commit fba99d3

Browse files
committed
Improve implementation, add tests
1 parent 5ae3aba commit fba99d3

File tree

3 files changed

+123
-5
lines changed

3 files changed

+123
-5
lines changed

src/Analyzer/Resolve/QueryAnalyzer.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2848,6 +2848,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
28482848
if (function_in_arguments_nodes.size() != 2)
28492849
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' expects 2 arguments", function_name);
28502850

2851+
auto & in_first_argument = function_in_arguments_nodes[0];
28512852
auto & in_second_argument = function_in_arguments_nodes[1];
28522853
auto * table_node = in_second_argument->as<TableNode>();
28532854
auto * table_function_node = in_second_argument->as<TableFunctionNode>();
@@ -2911,15 +2912,29 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
29112912

29122913
resolveExpressionNode(in_second_argument, scope, false /*allow_lambda_expression*/, true /*allow_table_expression*/);
29132914

2915+
/// Rewrite X in EXPR, where EXPR is a non-const expression, to has(EXPR, X).
29142916
if (auto * non_const_set_candidate = in_second_argument->as<FunctionNode>())
29152917
{
29162918
const auto & candidate_name = non_const_set_candidate->getFunctionName();
2917-
if (candidate_name == "array")
2919+
if (candidate_name == "array" && !isNullableOrLowCardinalityNullable(in_first_argument->getResultType()))
29182920
{
2919-
function_name = "has";
2920-
is_special_function_in = false;
2921-
auto & function_arguments = function_node.getArguments().getNodes();
2922-
std::swap(function_arguments[0], function_arguments[1]);
2921+
bool contains_nullable = false;
2922+
for (const auto & array_elem : non_const_set_candidate->getArguments())
2923+
{
2924+
if (isNullableOrLowCardinalityNullable(array_elem->getResultType()))
2925+
{
2926+
contains_nullable = true;
2927+
break;
2928+
}
2929+
}
2930+
2931+
if (!contains_nullable)
2932+
{
2933+
function_name = "has";
2934+
is_special_function_in = false;
2935+
auto & function_arguments = function_node.getArguments().getNodes();
2936+
std::swap(function_arguments[0], function_arguments[1]);
2937+
}
29232938
}
29242939
}
29252940
}

tests/queries/0_stateless/03173_non_const_in_arg.reference

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,59 @@
77
1
88
6
99
7
10+
-- MORE CASES --
11+
-- { echoOn }
12+
13+
SELECT null in [number % 3, number % 5] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
14+
SELECT null in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
15+
SELECT 5 in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
16+
SELECT (1, 2) in [number % 3, number % 5] FROM numbers(2); -- { serverError NO_COMMON_TYPE }
17+
SELECT (1, 2) in (SELECT [0, 0] UNION ALL SELECT [1, 1]); -- { serverError TYPE_MISMATCH }
18+
SELECT (1, 2) in [(number % 3, number % 5)] FROM numbers(2);
19+
0
20+
0
21+
SELECT (1, 2) in (SELECT (0, 0)), (1, 2) in (SELECT (1, 1));
22+
0 0
23+
SELECT (1, 1) in [(number % 3, number % 5)] FROM numbers(2);
24+
0
25+
1
26+
SELECT (1, 1) in (SELECT (0, 0)), (1, 1) in (SELECT (1, 1));
27+
0 1
28+
SELECT (1, null) in [(number % 3, number % 5)] FROM numbers(2);
29+
0
30+
0
31+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int))), (1, null) in (SELECT (1, 1::Nullable(Int)));
32+
0 0
33+
SELECT (1, null) in [(number % 3, number % 5), (1, null)] FROM numbers(2);
34+
1
35+
1
36+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int)) UNION ALL SELECT (1, null)), (1, null) in (SELECT (1, 1::Nullable(Int)) UNION ALL SELECT (1, null));
37+
1 1
38+
SELECT 'ANOTHER SETTING';
39+
ANOTHER SETTING
40+
set transform_null_in = 1;
41+
SELECT null in [number % 3, number % 5] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
42+
SELECT null in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
43+
SELECT 5 in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
44+
SELECT (1, 2) in [number % 3, number % 5] FROM numbers(2); -- { serverError NO_COMMON_TYPE }
45+
SELECT (1, 2) in (SELECT [0, 0] UNION ALL SELECT [1, 1]); -- { serverError TYPE_MISMATCH }
46+
SELECT (1, 2) in [(number % 3, number % 5)] FROM numbers(2);
47+
0
48+
0
49+
SELECT (1, 2) in (SELECT (0, 0)), (1, 2) in (SELECT (1, 1));
50+
0 0
51+
SELECT (1, 1) in [(number % 3, number % 5)] FROM numbers(2);
52+
0
53+
1
54+
SELECT (1, 1) in (SELECT (0, 0)), (1, 1) in (SELECT (1, 1));
55+
0 1
56+
SELECT (1, null) in [(number % 3, number % 5)] FROM numbers(2);
57+
0
58+
0
59+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int))), (1, null) in (SELECT (1, 1::Nullable(Int)));
60+
0 0
61+
SELECT (1, null) in [(number % 3, number % 5), (1, null)] FROM numbers(2);
62+
1
63+
1
64+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int)) UNION ALL SELECT (1, null)), (1, null) in (SELECT (1, 1::Nullable(Int)) UNION ALL SELECT (1, null));
65+
1 1

tests/queries/0_stateless/03173_non_const_in_arg.sql

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,50 @@ SELECT number FROM numbers(10) WHERE has([number % 3, number % 5], number % 2) O
22
SELECT '-- IN --';
33
SELECT number FROM numbers(10) WHERE number % 2 IN [number % 3, number % 5] ORDER BY number SETTINGS allow_experimental_analyzer = 1;
44
SELECT number FROM numbers(10) WHERE number % 2 IN [number % 3, number % 5] ORDER BY number SETTINGS allow_experimental_analyzer = 0; -- { serverError UNKNOWN_IDENTIFIER }
5+
6+
SELECT '-- MORE CASES --';
7+
8+
-- { echoOn }
9+
10+
SELECT null in [number % 3, number % 5] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
11+
SELECT null in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
12+
SELECT 5 in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
13+
14+
SELECT (1, 2) in [number % 3, number % 5] FROM numbers(2); -- { serverError NO_COMMON_TYPE }
15+
SELECT (1, 2) in (SELECT [0, 0] UNION ALL SELECT [1, 1]); -- { serverError TYPE_MISMATCH }
16+
17+
SELECT (1, 2) in [(number % 3, number % 5)] FROM numbers(2);
18+
SELECT (1, 2) in (SELECT (0, 0)), (1, 2) in (SELECT (1, 1));
19+
20+
SELECT (1, 1) in [(number % 3, number % 5)] FROM numbers(2);
21+
SELECT (1, 1) in (SELECT (0, 0)), (1, 1) in (SELECT (1, 1));
22+
23+
SELECT (1, null) in [(number % 3, number % 5)] FROM numbers(2);
24+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int))), (1, null) in (SELECT (1, 1::Nullable(Int)));
25+
26+
SELECT (1, null) in [(number % 3, number % 5), (1, null)] FROM numbers(2);
27+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int)) UNION ALL SELECT (1, null)), (1, null) in (SELECT (1, 1::Nullable(Int)) UNION ALL SELECT (1, null));
28+
29+
SELECT 'ANOTHER SETTING';
30+
31+
set transform_null_in = 1;
32+
33+
SELECT null in [number % 3, number % 5] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
34+
SELECT null in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
35+
SELECT 5 in [number % 3, number % 5, null] FROM numbers(2); -- { serverError UNSUPPORTED_METHOD }
36+
37+
SELECT (1, 2) in [number % 3, number % 5] FROM numbers(2); -- { serverError NO_COMMON_TYPE }
38+
SELECT (1, 2) in (SELECT [0, 0] UNION ALL SELECT [1, 1]); -- { serverError TYPE_MISMATCH }
39+
40+
SELECT (1, 2) in [(number % 3, number % 5)] FROM numbers(2);
41+
SELECT (1, 2) in (SELECT (0, 0)), (1, 2) in (SELECT (1, 1));
42+
43+
SELECT (1, 1) in [(number % 3, number % 5)] FROM numbers(2);
44+
SELECT (1, 1) in (SELECT (0, 0)), (1, 1) in (SELECT (1, 1));
45+
46+
SELECT (1, null) in [(number % 3, number % 5)] FROM numbers(2);
47+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int))), (1, null) in (SELECT (1, 1::Nullable(Int)));
48+
49+
SELECT (1, null) in [(number % 3, number % 5), (1, null)] FROM numbers(2);
50+
SELECT (1, null) in (SELECT (0, 0::Nullable(Int)) UNION ALL SELECT (1, null)), (1, null) in (SELECT (1, 1::Nullable(Int)) UNION ALL SELECT (1, null));
51+

0 commit comments

Comments
 (0)