@@ -158,7 +158,7 @@ struct NumIfImpl<A, B, NumberTraits::Error>
158158private:
159159 [[noreturn]] static void throwError ()
160160 {
161- throw Exception (" Internal logic error: invalid types of arguments 2 and 3 of if" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
161+ throw Exception (" Invalid types of arguments 2 and 3 of if" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
162162 }
163163public:
164164 template <typename ... Args> static void vectorVector (Args &&...) { throwError (); }
@@ -656,30 +656,89 @@ class FunctionIf : public FunctionIfBase</*null_is_false=*/false>
656656 block.getByPosition (result).column = std::move (result_column);
657657 }
658658
659- bool executeForNullableCondition (Block & block, const ColumnNumbers & arguments, size_t result, size_t /* input_rows_count*/ )
659+ bool executeForConstAndNullableCondition (Block & block, const ColumnNumbers & arguments, size_t result, size_t /* input_rows_count*/ )
660660 {
661661 const ColumnWithTypeAndName & arg_cond = block.getByPosition (arguments[0 ]);
662662 bool cond_is_null = arg_cond.column ->onlyNull ();
663663
664- if (cond_is_null)
664+ ColumnPtr not_const_condition = arg_cond.column ;
665+ bool cond_is_const = false ;
666+ bool cond_is_true = false ;
667+ bool cond_is_false = false ;
668+ if (const auto * const_arg = checkAndGetColumn<ColumnConst>(*arg_cond.column ))
665669 {
666- block.getByPosition (result).column = std::move (block.getByPosition (arguments[2 ]).column );
667- return true ;
670+ cond_is_const = true ;
671+ not_const_condition = const_arg->getDataColumnPtr ();
672+ ColumnPtr data_column = const_arg->getDataColumnPtr ();
673+ if (const auto * const_nullable_arg = checkAndGetColumn<ColumnNullable>(*data_column))
674+ {
675+ data_column = const_nullable_arg->getNestedColumnPtr ();
676+ if (!data_column->empty ())
677+ cond_is_null = const_nullable_arg->getNullMapData ()[0 ];
678+ }
679+
680+ if (!data_column->empty ())
681+ {
682+ cond_is_true = !cond_is_null && checkAndGetColumn<ColumnUInt8>(*data_column)->getBool (0 );
683+ cond_is_false = !cond_is_null && !cond_is_true;
684+ }
668685 }
669686
670- if (const auto * nullable = checkAndGetColumn<ColumnNullable>(*arg_cond.column ))
687+ const auto & column1 = block.getByPosition (arguments[1 ]);
688+ const auto & column2 = block.getByPosition (arguments[2 ]);
689+ auto & result_column = block.getByPosition (result);
690+
691+ if (cond_is_true)
671692 {
693+ if (result_column.type ->equals (*column1.type ))
694+ {
695+ result_column.column = std::move (column1.column );
696+ return true ;
697+ }
698+ }
699+ else if (cond_is_false || cond_is_null)
700+ {
701+ if (result_column.type ->equals (*column2.type ))
702+ {
703+ result_column.column = std::move (column2.column );
704+ return true ;
705+ }
706+ }
707+
708+ if (const auto * nullable = checkAndGetColumn<ColumnNullable>(*not_const_condition))
709+ {
710+ ColumnPtr new_cond_column = nullable->getNestedColumnPtr ();
711+ size_t column_size = arg_cond.column ->size ();
712+
713+ if (cond_is_null || cond_is_true || cond_is_false) // / Nullable(Nothing) or consts
714+ {
715+ UInt8 value = cond_is_true ? 1 : 0 ;
716+ new_cond_column = ColumnConst::create (ColumnUInt8::create (1 , value), column_size);
717+ }
718+ else if (checkAndGetColumn<ColumnUInt8>(*new_cond_column))
719+ {
720+ auto nested_column_copy = new_cond_column->cloneResized (new_cond_column->size ());
721+ typeid_cast<ColumnUInt8 *>(nested_column_copy.get ())->applyZeroMap (nullable->getNullMapData ());
722+ new_cond_column = std::move (nested_column_copy);
723+
724+ if (cond_is_const)
725+ new_cond_column = ColumnConst::create (new_cond_column, column_size);
726+ }
727+ else
728+ throw Exception (" Illegal column " + arg_cond.column ->getName () + " of " + getName () + " condition" ,
729+ ErrorCodes::ILLEGAL_COLUMN);
730+
672731 Block temporary_block
673732 {
674- { nullable-> getNestedColumnPtr () , removeNullable (arg_cond.type ), arg_cond.name },
675- block. getByPosition (arguments[ 1 ]) ,
676- block. getByPosition (arguments[ 2 ]) ,
677- block. getByPosition (result)
733+ { new_cond_column , removeNullable (arg_cond.type ), arg_cond.name },
734+ column1 ,
735+ column2 ,
736+ result_column
678737 };
679738
680739 executeImpl (temporary_block, {0 , 1 , 2 }, 3 , temporary_block.rows ());
681740
682- block. getByPosition (result) .column = std::move (temporary_block.getByPosition (3 ).column );
741+ result_column .column = std::move (temporary_block.getByPosition (3 ).column );
683742 return true ;
684743 }
685744
@@ -916,7 +975,7 @@ class FunctionIf : public FunctionIfBase</*null_is_false=*/false>
916975
917976 void executeImpl (Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
918977 {
919- if (executeForNullableCondition (block, arguments, result, input_rows_count)
978+ if (executeForConstAndNullableCondition (block, arguments, result, input_rows_count)
920979 || executeForNullThenElse (block, arguments, result, input_rows_count)
921980 || executeForNullableThenElse (block, arguments, result, input_rows_count))
922981 return ;
@@ -964,10 +1023,7 @@ class FunctionIf : public FunctionIfBase</*null_is_false=*/false>
9641023 using T0 = typename Types::LeftType;
9651024 using T1 = typename Types::RightType;
9661025
967- if constexpr (IsDecimalNumber<T0> == IsDecimalNumber<T1>)
968- return executeTyped<T0, T1>(cond_col, block, arguments, result, input_rows_count);
969- else
970- throw Exception (" Conditional function with Decimal and non Decimal" , ErrorCodes::NOT_IMPLEMENTED);
1026+ return executeTyped<T0, T1>(cond_col, block, arguments, result, input_rows_count);
9711027 };
9721028
9731029 TypeIndex left_id = arg_then.type ->getTypeId ();
0 commit comments