Skip to content

Commit 92016e8

Browse files
authored
Fix if function with NULLs (#11807)
1 parent 55eee9b commit 92016e8

File tree

8 files changed

+226
-65
lines changed

8 files changed

+226
-65
lines changed

src/Columns/ColumnNullable.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,6 @@ void ColumnNullable::applyNullMap(const ColumnNullable & other)
562562
applyNullMap(other.getNullMapColumn());
563563
}
564564

565-
566565
void ColumnNullable::checkConsistency() const
567566
{
568567
if (null_map->size() != getNestedColumn().size())

src/Columns/ColumnVector.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,31 @@ ColumnPtr ColumnVector<T>::filter(const IColumn::Filter & filt, ssize_t result_s
408408
return res;
409409
}
410410

411+
template <typename T>
412+
void ColumnVector<T>::applyZeroMap(const IColumn::Filter & filt, bool inverted)
413+
{
414+
size_t size = data.size();
415+
if (size != filt.size())
416+
throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
417+
418+
const UInt8 * filt_pos = filt.data();
419+
const UInt8 * filt_end = filt_pos + size;
420+
T * data_pos = data.data();
421+
422+
if (inverted)
423+
{
424+
for (; filt_pos < filt_end; ++filt_pos, ++data_pos)
425+
if (!*filt_pos)
426+
*data_pos = 0;
427+
}
428+
else
429+
{
430+
for (; filt_pos < filt_end; ++filt_pos, ++data_pos)
431+
if (*filt_pos)
432+
*data_pos = 0;
433+
}
434+
}
435+
411436
template <typename T>
412437
ColumnPtr ColumnVector<T>::permute(const IColumn::Permutation & perm, size_t limit) const
413438
{

src/Columns/ColumnVector.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ class ColumnVector final : public COWHelper<ColumnVectorHelper, ColumnVector<T>>
285285
return typeid(rhs) == typeid(ColumnVector<T>);
286286
}
287287

288+
/// Replace elements that match the filter with zeroes. If inverted replaces not matched elements.
289+
void applyZeroMap(const IColumn::Filter & filt, bool inverted = false);
290+
288291
/** More efficient methods of manipulation - to manipulate with data directly. */
289292
Container & getData()
290293
{

src/Functions/if.cpp

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ struct NumIfImpl<A, B, NumberTraits::Error>
158158
private:
159159
[[noreturn]] static void throwError()
160160
{
161-
throw Exception("Internal logic error: invalid types of arguments 2 and 3 of if", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
161+
throw Exception("Invalid types of arguments 2 and 3 of if", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
162162
}
163163
public:
164164
template <typename... Args> static void vectorVector(Args &&...) { throwError(); }
@@ -656,30 +656,89 @@ class FunctionIf : public FunctionIfBase</*null_is_false=*/false>
656656
block.getByPosition(result).column = std::move(result_column);
657657
}
658658

659-
bool executeForNullableCondition(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
659+
bool executeForConstAndNullableCondition(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
660660
{
661661
const ColumnWithTypeAndName & arg_cond = block.getByPosition(arguments[0]);
662662
bool cond_is_null = arg_cond.column->onlyNull();
663663

664-
if (cond_is_null)
664+
ColumnPtr not_const_condition = arg_cond.column;
665+
bool cond_is_const = false;
666+
bool cond_is_true = false;
667+
bool cond_is_false = false;
668+
if (const auto * const_arg = checkAndGetColumn<ColumnConst>(*arg_cond.column))
665669
{
666-
block.getByPosition(result).column = std::move(block.getByPosition(arguments[2]).column);
667-
return true;
670+
cond_is_const = true;
671+
not_const_condition = const_arg->getDataColumnPtr();
672+
ColumnPtr data_column = const_arg->getDataColumnPtr();
673+
if (const auto * const_nullable_arg = checkAndGetColumn<ColumnNullable>(*data_column))
674+
{
675+
data_column = const_nullable_arg->getNestedColumnPtr();
676+
if (!data_column->empty())
677+
cond_is_null = const_nullable_arg->getNullMapData()[0];
678+
}
679+
680+
if (!data_column->empty())
681+
{
682+
cond_is_true = !cond_is_null && checkAndGetColumn<ColumnUInt8>(*data_column)->getBool(0);
683+
cond_is_false = !cond_is_null && !cond_is_true;
684+
}
668685
}
669686

670-
if (const auto * nullable = checkAndGetColumn<ColumnNullable>(*arg_cond.column))
687+
const auto & column1 = block.getByPosition(arguments[1]);
688+
const auto & column2 = block.getByPosition(arguments[2]);
689+
auto & result_column = block.getByPosition(result);
690+
691+
if (cond_is_true)
671692
{
693+
if (result_column.type->equals(*column1.type))
694+
{
695+
result_column.column = std::move(column1.column);
696+
return true;
697+
}
698+
}
699+
else if (cond_is_false || cond_is_null)
700+
{
701+
if (result_column.type->equals(*column2.type))
702+
{
703+
result_column.column = std::move(column2.column);
704+
return true;
705+
}
706+
}
707+
708+
if (const auto * nullable = checkAndGetColumn<ColumnNullable>(*not_const_condition))
709+
{
710+
ColumnPtr new_cond_column = nullable->getNestedColumnPtr();
711+
size_t column_size = arg_cond.column->size();
712+
713+
if (cond_is_null || cond_is_true || cond_is_false) /// Nullable(Nothing) or consts
714+
{
715+
UInt8 value = cond_is_true ? 1 : 0;
716+
new_cond_column = ColumnConst::create(ColumnUInt8::create(1, value), column_size);
717+
}
718+
else if (checkAndGetColumn<ColumnUInt8>(*new_cond_column))
719+
{
720+
auto nested_column_copy = new_cond_column->cloneResized(new_cond_column->size());
721+
typeid_cast<ColumnUInt8 *>(nested_column_copy.get())->applyZeroMap(nullable->getNullMapData());
722+
new_cond_column = std::move(nested_column_copy);
723+
724+
if (cond_is_const)
725+
new_cond_column = ColumnConst::create(new_cond_column, column_size);
726+
}
727+
else
728+
throw Exception("Illegal column " + arg_cond.column->getName() + " of " + getName() + " condition",
729+
ErrorCodes::ILLEGAL_COLUMN);
730+
672731
Block temporary_block
673732
{
674-
{ nullable->getNestedColumnPtr(), removeNullable(arg_cond.type), arg_cond.name },
675-
block.getByPosition(arguments[1]),
676-
block.getByPosition(arguments[2]),
677-
block.getByPosition(result)
733+
{ new_cond_column, removeNullable(arg_cond.type), arg_cond.name },
734+
column1,
735+
column2,
736+
result_column
678737
};
679738

680739
executeImpl(temporary_block, {0, 1, 2}, 3, temporary_block.rows());
681740

682-
block.getByPosition(result).column = std::move(temporary_block.getByPosition(3).column);
741+
result_column.column = std::move(temporary_block.getByPosition(3).column);
683742
return true;
684743
}
685744

@@ -916,7 +975,7 @@ class FunctionIf : public FunctionIfBase</*null_is_false=*/false>
916975

917976
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
918977
{
919-
if (executeForNullableCondition(block, arguments, result, input_rows_count)
978+
if (executeForConstAndNullableCondition(block, arguments, result, input_rows_count)
920979
|| executeForNullThenElse(block, arguments, result, input_rows_count)
921980
|| executeForNullableThenElse(block, arguments, result, input_rows_count))
922981
return;
@@ -964,10 +1023,7 @@ class FunctionIf : public FunctionIfBase</*null_is_false=*/false>
9641023
using T0 = typename Types::LeftType;
9651024
using T1 = typename Types::RightType;
9661025

967-
if constexpr (IsDecimalNumber<T0> == IsDecimalNumber<T1>)
968-
return executeTyped<T0, T1>(cond_col, block, arguments, result, input_rows_count);
969-
else
970-
throw Exception("Conditional function with Decimal and non Decimal", ErrorCodes::NOT_IMPLEMENTED);
1026+
return executeTyped<T0, T1>(cond_col, block, arguments, result, input_rows_count);
9711027
};
9721028

9731029
TypeIndex left_id = arg_then.type->getTypeId();

tests/queries/0_stateless/00735_conditional.reference

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ value vs value
88
0 1 1 Int8 UInt32 Int64
99
0 1 1 Int8 Float32 Float32
1010
0 1 1 Int8 Float64 Float64
11+
0 1 1 Int8 Decimal(9, 0) Decimal(9, 0)
12+
0 1 1 Int8 Decimal(18, 0) Decimal(18, 0)
13+
0 1 1 Int8 Decimal(38, 0) Decimal(38, 0)
1114
0 1 1 Int16 Int8 Int16
1215
0 1 1 Int16 Int16 Int16
1316
0 1 1 Int16 Int32 Int32
@@ -17,6 +20,9 @@ value vs value
1720
0 1 1 Int16 UInt32 Int64
1821
0 1 1 Int16 Float32 Float32
1922
0 1 1 Int16 Float64 Float64
23+
0 1 1 Int16 Decimal(9, 0) Decimal(9, 0)
24+
0 1 1 Int16 Decimal(18, 0) Decimal(18, 0)
25+
0 1 1 Int16 Decimal(38, 0) Decimal(38, 0)
2026
0 1 1 Int32 Int8 Int32
2127
0 1 1 Int32 Int16 Int32
2228
0 1 1 Int32 Int32 Int32
@@ -26,13 +32,18 @@ value vs value
2632
0 1 1 Int32 UInt32 Int64
2733
0 1 1 Int32 Float32 Float64
2834
0 1 1 Int32 Float64 Float64
35+
0 1 1 Int32 Decimal(9, 0) Decimal(9, 0)
36+
0 1 1 Int32 Decimal(18, 0) Decimal(18, 0)
37+
0 1 1 Int32 Decimal(38, 0) Decimal(38, 0)
2938
0 1 1 Int64 Int8 Int64
3039
0 1 1 Int64 Int16 Int64
3140
0 1 1 Int64 Int32 Int64
3241
0 1 1 Int64 Int64 Int64
3342
0 1 1 Int64 UInt8 Int64
3443
0 1 1 Int64 UInt16 Int64
3544
0 1 1 Int64 UInt32 Int64
45+
0 1 1 Int64 Decimal(18, 0) Decimal(18, 0)
46+
0 1 1 Int64 Decimal(38, 0) Decimal(38, 0)
3647
0 1 1 UInt8 Int8 Int16
3748
0 1 1 UInt8 Int16 Int16
3849
0 1 1 UInt8 Int32 Int32
@@ -43,6 +54,9 @@ value vs value
4354
0 1 1 UInt8 UInt64 UInt64
4455
0 1 1 UInt8 Float32 Float32
4556
0 1 1 UInt8 Float64 Float64
57+
0 1 1 UInt8 Decimal(9, 0) Decimal(9, 0)
58+
0 1 1 UInt8 Decimal(18, 0) Decimal(18, 0)
59+
0 1 1 UInt8 Decimal(38, 0) Decimal(38, 0)
4660
0 1 1 UInt16 Int8 Int32
4761
0 1 1 UInt16 Int16 Int32
4862
0 1 1 UInt16 Int32 Int32
@@ -53,6 +67,9 @@ value vs value
5367
0 1 1 UInt16 UInt64 UInt64
5468
0 1 1 UInt16 Float32 Float32
5569
0 1 1 UInt16 Float64 Float64
70+
0 1 1 UInt16 Decimal(9, 0) Decimal(9, 0)
71+
0 1 1 UInt16 Decimal(18, 0) Decimal(18, 0)
72+
0 1 1 UInt16 Decimal(38, 0) Decimal(38, 0)
5673
0 1 1 UInt32 Int8 Int64
5774
0 1 1 UInt32 Int16 Int64
5875
0 1 1 UInt32 Int32 Int64
@@ -63,10 +80,13 @@ value vs value
6380
0 1 1 UInt32 UInt64 UInt64
6481
0 1 1 UInt32 Float32 Float64
6582
0 1 1 UInt32 Float64 Float64
83+
0 1 1 UInt32 Decimal(18, 0) Decimal(18, 0)
84+
0 1 1 UInt32 Decimal(38, 0) Decimal(38, 0)
6685
0 1 1 UInt64 UInt8 UInt64
6786
0 1 1 UInt64 UInt16 UInt64
6887
0 1 1 UInt64 UInt32 UInt64
6988
0 1 1 UInt64 UInt64 UInt64
89+
0 1 1 UInt64 Decimal(38, 0) Decimal(38, 0)
7090
0000-00-00 1970-01-02 1970-01-02 Date Date Date
7191
2000-01-01 2000-01-01 00:00:01 2000-01-01 00:00:01 Date DateTime(\'Europe/Moscow\') DateTime
7292
2000-01-01 00:00:00 2000-01-02 2000-01-02 00:00:00 DateTime(\'Europe/Moscow\') Date DateTime

0 commit comments

Comments
 (0)