@@ -27,6 +27,14 @@ namespace ErrorCodes
2727
2828class FunctionToStartOfInterval : public IFunction
2929{
30+ private:
/// Which call signature of toStartOfInterval() was resolved while checking argument types.
/// The value is written during return-type deduction and read again during execution,
/// which is why the member is mutable.
31+ enum class Overload
32+ {
33+ Default, // / toStartOfInterval(time, interval) or toStartOfInterval(time, interval, timezone)
34+ Origin // / toStartOfInterval(time, interval, origin) or toStartOfInterval(time, interval, origin, timezone)
35+ };
/// Initialised to Default: the only visible assignments happen in the 3rd-argument check,
/// so a two-argument call would otherwise read an indeterminate value when the timezone
/// argument index is derived from this member.
36+ mutable Overload overload = Overload::Default;
37+
3038public:
/// Factory entry point: returns a freshly allocated FunctionToStartOfInterval.
/// The ContextPtr parameter is unnamed and not used.
3139 static FunctionPtr create (ContextPtr) { return std::make_shared<FunctionToStartOfInterval>(); }
3240
@@ -36,7 +44,7 @@ class FunctionToStartOfInterval : public IFunction
/// Trait overrides describing this function to the caller of the IFunction interface.

/// Returns 0. NOTE(review): presumably this pairs with a variadic flag declared outside this view;
/// the accepted argument counts (2, 3 or 4) are enforced explicitly in the return-type check.
3644 size_t getNumberOfArguments () const override { return 0 ; }
/// Always false; the argument list is ignored.
3745 bool isSuitableForShortCircuitArgumentsExecution (const DataTypesWithConstInfo & /* arguments*/ ) const override { return false ; }
3846 bool useDefaultImplementationForConstants () const override { return true ; }
/// Zero-based indices of arguments reported as always constant. The old version (removed line) listed
/// 1 and 2; the new version adds 3, the position of the timezone in the four-argument overload.
39- ColumnNumbers getArgumentsThatAreAlwaysConstant () const override { return {1 , 2 }; }
47+ ColumnNumbers getArgumentsThatAreAlwaysConstant () const override { return {1 , 2 , 3 }; }
4048 bool hasInformationAboutMonotonicity () const override { return true ; }
/// Reports the function as monotonic and always monotonic; the type and range arguments are ignored.
4149 Monotonicity getMonotonicityForRange (const IDataType &, const Field &, const Field &) const override { return { .is_monotonic = true , .is_always_monotonic = true }; }
4250
@@ -96,13 +104,49 @@ class FunctionToStartOfInterval : public IFunction
/// Validates the 3rd argument and records which overload is in use.
///  - String                       -> timezone, overload = Default
///  - Date / DateTime / DateTime64 -> origin,   overload = Origin (type must equal the 1st argument's type)
///  - anything else                -> throws ILLEGAL_TYPE_OF_ARGUMENT
/// NOTE(review): this lambda holds the only visible assignments to `overload`; confirm that
/// two-argument calls, which do not appear to run it, still see a defined value.
96104 auto check_third_argument = [&]
97105 {
98106 const DataTypePtr & type_arg3 = arguments[2 ].type ;
99- if (!isString (type_arg3))
107+
108+ if (isString (type_arg3))
109+ {
110+ overload = Overload::Default;
111+
/// Rejects a timezone only when the time value is a Date AND the result type is Date.
112+ if (value_is_date && result_type == ResultType::Date) // / weird why this is && instead of || but too afraid to change it
113+ throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
114+ " The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64" ,
115+ getName (), interval_type->getKind ().toString ());
116+ }
117+ else if (isDate (type_arg3) || isDateTime (type_arg3) || isDateTime64 (type_arg3))
118+ {
119+ overload = Overload::Origin;
120+
121+ // / For simplicity, require the time and origin arguments have the same type and scale
122+ const DataTypePtr & type_arg1 = arguments[0 ].type ;
123+
/// equals() compares the full data types of the 1st and 3rd arguments.
124+ if (!type_arg1->equals (*type_arg3))
125+ throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, " 1st and 3rd argument for function {} must have the same type" , getName ());
126+ // / As a lemma, if both types are DateTime64, they will have the same scale
127+ }
128+ else
100129 throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
101- " Illegal type {} of 3rd argument of function {}, expected a constant timezone string" ,
130+ " Illegal type {} of 3rd argument of function {}, expected constant timezone string or constant origin of type Date, DateTime or DateTime64 " ,
102131 type_arg3->getName (), getName ());
103- if (value_is_date && result_type == ResultType::Date) // / weird why this is && instead of || but too afraid to change it
132+ };
133+
/// Validates the 4th argument, which is only meaningful for the Origin overload
/// (time, interval, origin, timezone). Reads `overload`, so the 3rd-argument check must run first.
134+ auto check_fourth_argument = [&]
135+ {
/// With four arguments the 3rd one must have been an origin; a String there (Default overload)
/// is reported as an illegal 3rd-argument type.
136+ if (overload == Overload::Default) // / sanity check
137+ throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
138+ " Illegal type {} of 3rd argument of function {}, expected a const origin of type Date, DateTime or DateTime64" ,
139+ arguments[2 ].type ->getName (), getName ());
140+
141+ const DataTypePtr & type_arg4 = arguments[3 ].type ;
142+ if (!isString (type_arg4))
104143 throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
105- " The timezone argument of function {} with interval type {} is allowed only when the 1st argument has type DateTime or DateTimt64" ,
144+ " Illegal type {} of 4th argument of function {}, expected constant timezone string" ,
145+ type_arg4->getName (), getName ());
146+
/// NOTE(review): this condition uses ||, whereas the timezone check for the 3rd argument uses &&,
/// so the trailing "same as" remark does not hold as written. Confirm which operator is intended.
147+ if (value_is_date || result_type == ResultType::Date) // / same as in check_third_argument()
148+ throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
149+ " The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64" ,
106150 getName (), interval_type->getKind ().toString ());
107151 };
108152
@@ -117,21 +161,34 @@ class FunctionToStartOfInterval : public IFunction
117161 check_second_argument ();
118162 check_third_argument ();
119163 }
164+ else if (arguments.size () == 4 )
165+ {
166+ check_first_argument ();
167+ check_second_argument ();
168+ check_third_argument ();
169+ check_fourth_argument ();
170+ }
120171 else
121172 {
122173 throw Exception (ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
123- " Number of arguments for function {} doesn't match: passed {}, should be 2 or 3 " ,
174+ " Number of arguments for function {} doesn't match: passed {}, should be 2, 3 or 4 " ,
124175 getName (), arguments.size ());
125176 }
126177
127178 switch (result_type)
128179 {
129180 case ResultType::Date:
181+ {
130182 return std::make_shared<DataTypeDate>();
183+ }
131184 case ResultType::DateTime:
132- return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments (arguments, 2 , 0 , false ));
185+ {
186+ const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3 ;
187+ return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments (arguments, time_zone_arg_num, 0 , false ));
188+ }
133189 case ResultType::DateTime64:
134190 {
191+ // / TODO complex stuff is added here
135192 UInt32 scale = 0 ;
136193 if (interval_type->getKind () == IntervalKind::Nanosecond)
137194 scale = 9 ;
@@ -140,7 +197,8 @@ class FunctionToStartOfInterval : public IFunction
140197 else if (interval_type->getKind () == IntervalKind::Millisecond)
141198 scale = 3 ;
142199
143- return std::make_shared<DataTypeDateTime64>(scale, extractTimeZoneNameFromFunctionArguments (arguments, 2 , 0 , false ));
200+ const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3 ;
201+ return std::make_shared<DataTypeDateTime64>(scale, extractTimeZoneNameFromFunctionArguments (arguments, time_zone_arg_num, 0 , false ));
144202 }
145203 }
146204
@@ -151,14 +209,21 @@ class FunctionToStartOfInterval : public IFunction
151209 {
152210 const auto & time_column = arguments[0 ];
153211 const auto & interval_column = arguments[1 ];
154- const auto & time_zone = extractTimeZoneFromFunctionArguments (arguments, 2 , 0 );
155- auto result_column = dispatchForTimeColumn (time_column, interval_column, result_type, time_zone);
212+
213+ ColumnWithTypeAndName origin_column;
214+ if (overload == Overload::Origin)
215+ origin_column = arguments[2 ];
216+
217+ const size_t time_zone_arg_num = (overload == Overload::Origin) ? 3 : 2 ;
218+ const auto & time_zone = extractTimeZoneFromFunctionArguments (arguments, time_zone_arg_num, 0 );
219+
220+ auto result_column = dispatchForTimeColumn (time_column, interval_column, origin_column, result_type, time_zone);
156221 return result_column;
157222 }
158223
159224private:
/// First dispatch level: selects the concrete column type of the time argument
/// (DateTime64, DateTime or Date) and forwards to dispatchForIntervalColumn().
/// The new version threads `origin_column` through unchanged to the next level.
/// Throws ILLEGAL_TYPE_OF_ARGUMENT when no branch returns, i.e. the type matched none of the
/// three cases or the column cast yielded null.
/// Part of the body is elided by the diff hunk boundary below; comments cover visible lines only.
160225 ColumnPtr dispatchForTimeColumn (
161- const ColumnWithTypeAndName & time_column, const ColumnWithTypeAndName & interval_column,
226+ const ColumnWithTypeAndName & time_column, const ColumnWithTypeAndName & interval_column, const ColumnWithTypeAndName & origin_column,
162227 const DataTypePtr & result_type, const DateLUTImpl & time_zone) const
163228 {
164229 const auto & time_column_type = *time_column.type .get ();
@@ -170,26 +235,26 @@ class FunctionToStartOfInterval : public IFunction
/// DateTime64 is the only branch that passes an explicit scale, taken from the data type.
170235 auto scale = assert_cast<const DataTypeDateTime64 &>(time_column_type).getScale ();
171236
172237 if (time_column_vec)
173- return dispatchForIntervalColumn (assert_cast<const DataTypeDateTime64 &>(time_column_type), *time_column_vec, interval_column, result_type, time_zone, scale);
238+ return dispatchForIntervalColumn (assert_cast<const DataTypeDateTime64 &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone, scale);
174239 }
175240 else if (isDateTime (time_column_type))
176241 {
177242 const auto * time_column_vec = checkAndGetColumn<ColumnDateTime>(time_column_col);
178243 if (time_column_vec)
179- return dispatchForIntervalColumn (assert_cast<const DataTypeDateTime &>(time_column_type), *time_column_vec, interval_column, result_type, time_zone);
244+ return dispatchForIntervalColumn (assert_cast<const DataTypeDateTime &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
180245 }
181246 else if (isDate (time_column_type))
182247 {
183248 const auto * time_column_vec = checkAndGetColumn<ColumnDate>(time_column_col);
184249 if (time_column_vec)
185- return dispatchForIntervalColumn (assert_cast<const DataTypeDate &>(time_column_type), *time_column_vec, interval_column, result_type, time_zone);
250+ return dispatchForIntervalColumn (assert_cast<const DataTypeDate &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
186251 }
187252 throw Exception (ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, " Illegal column for 1st argument of function {}, expected a Date, DateTime or DateTime64" , getName ());
188253 }
189254
/// Second dispatch level: turns the runtime interval kind into a compile-time template argument
/// and calls execute<>(). The result data type chosen per kind, as visible in the switch:
///   Nanosecond / Microsecond / Millisecond -> DataTypeDateTime64
///   Second / Minute / Hour / Day           -> DataTypeDateTime
///   Week / Month / Quarter / Year          -> DataTypeDate
/// `scale` defaults to 1 when the caller does not pass it. The new version forwards
/// `origin_column` to execute<>() right after `num_units`.
/// Part of the body (including where `num_units` is obtained) is elided by the diff hunk boundary.
190255 template <typename TimeDataType, typename TimeColumnType>
191256 ColumnPtr dispatchForIntervalColumn (
192- const TimeDataType & time_data_type, const TimeColumnType & time_column, const ColumnWithTypeAndName & interval_column,
257+ const TimeDataType & time_data_type, const TimeColumnType & time_column, const ColumnWithTypeAndName & interval_column, const ColumnWithTypeAndName & origin_column,
193258 const DataTypePtr & result_type, const DateLUTImpl & time_zone, UInt16 scale = 1 ) const
194259 {
195260 const auto * interval_type = checkAndGetDataType<DataTypeInterval>(interval_column.type .get ());
@@ -207,35 +272,35 @@ class FunctionToStartOfInterval : public IFunction
207272 switch (interval_type->getKind ())
208273 {
209274 case IntervalKind::Nanosecond:
210- return execute<TimeDataType, TimeColumnType, DataTypeDateTime64, IntervalKind::Nanosecond>(time_data_type, time_column, num_units, result_type, time_zone, scale);
275+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime64, IntervalKind::Nanosecond>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
211276 case IntervalKind::Microsecond:
212- return execute<TimeDataType, TimeColumnType, DataTypeDateTime64, IntervalKind::Microsecond>(time_data_type, time_column, num_units, result_type, time_zone, scale);
277+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime64, IntervalKind::Microsecond>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
213278 case IntervalKind::Millisecond:
214- return execute<TimeDataType, TimeColumnType, DataTypeDateTime64, IntervalKind::Millisecond>(time_data_type, time_column, num_units, result_type, time_zone, scale);
279+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime64, IntervalKind::Millisecond>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
215280 case IntervalKind::Second:
216- return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Second>(time_data_type, time_column, num_units, result_type, time_zone, scale);
281+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Second>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
217282 case IntervalKind::Minute:
218- return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Minute>(time_data_type, time_column, num_units, result_type, time_zone, scale);
283+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Minute>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
219284 case IntervalKind::Hour:
220- return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Hour>(time_data_type, time_column, num_units, result_type, time_zone, scale);
285+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Hour>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
221286 case IntervalKind::Day:
222- return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Day>(time_data_type, time_column, num_units, result_type, time_zone, scale);
287+ return execute<TimeDataType, TimeColumnType, DataTypeDateTime, IntervalKind::Day>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
223288 case IntervalKind::Week:
224- return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Week>(time_data_type, time_column, num_units, result_type, time_zone, scale);
289+ return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Week>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
225290 case IntervalKind::Month:
226- return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Month>(time_data_type, time_column, num_units, result_type, time_zone, scale);
291+ return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Month>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
227292 case IntervalKind::Quarter:
228- return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Quarter>(time_data_type, time_column, num_units, result_type, time_zone, scale);
293+ return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Quarter>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
229294 case IntervalKind::Year:
230- return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Year>(time_data_type, time_column, num_units, result_type, time_zone, scale);
295+ return execute<TimeDataType, TimeColumnType, DataTypeDate, IntervalKind::Year>(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale);
231296 }
232297
/// Reached only if the switch above did not return.
233298 std::unreachable ();
234299 }
235300
/// Innermost level: fills the result column row by row by calling ToStartOfInterval<unit>::execute().
/// The old version (removed lines) had a single loop; the new version branches on whether an origin
/// column was supplied. Part of the body (the declarations of `result_col`, `result_data`,
/// `time_data`, `size`, `ResultFieldType`) is elided by the diff hunk boundary below.
236301 template <typename TimeDataType, typename TimeColumnType, typename ResultDataType, IntervalKind::Kind unit>
237302 ColumnPtr execute (
238- const TimeDataType &, const TimeColumnType & time_column_type, Int64 num_units,
303+ const TimeDataType & time_data_type , const TimeColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column ,
239304 const DataTypePtr & result_type, const DateLUTImpl & time_zone, UInt16 scale) const
240305 {
241306 using ResultColumnType = typename ResultDataType::ColumnType;
@@ -251,8 +316,39 @@ class FunctionToStartOfInterval : public IFunction
251316
252317 Int64 scale_multiplier = DecimalUtils::scaleMultiplier<DateTime64>(scale);
253318
254- for (size_t i = 0 ; i != size; ++i)
255- result_data[i] = static_cast <ResultFieldType>(ToStartOfInterval<unit>::execute (time_data[i], num_units, time_zone, scale_multiplier));
/// A null column pointer means no origin was passed, so the original rounding loop is used.
319+ if (origin_column.column == nullptr )
320+ {
321+ // / Default overload
322+ for (size_t i = 0 ; i != size; ++i)
323+ result_data[i] = static_cast <ResultFieldType>(ToStartOfInterval<unit>::execute (time_data[i], num_units, time_zone, scale_multiplier));
324+ }
325+ else
326+ {
327+ // / Origin overload
328+ static constexpr size_t SECONDS_PER_DAY = 86400 ;
329+
/// Only row 0 of the origin column is read, so the origin is treated as one value for all rows.
330+ UInt64 origin = origin_column.column ->get64 (0 );
331+ for (size_t i = 0 ; i != size; ++i)
332+ {
333+ auto t = time_data[i];
/// NOTE(review): origin == t passes this check although the message says "earlier".
/// NOTE(review): the cast to UInt64 would turn a negative time value into a very large one - confirm
/// whether negative values can occur here. The throw also sits inside the per-row loop.
334+ if (origin > static_cast <UInt64>(t))
335+ throw Exception (ErrorCodes::BAD_ARGUMENTS, " The origin value must be earlier than the time value" );
336+
/// Shift the time by the origin, round the shifted value, then add the origin back below.
337+ t -= origin;
338+ auto res = static_cast <ResultFieldType>(ToStartOfInterval<unit>::execute (t, num_units, time_zone, scale_multiplier));
339+
/// Zeroing followed by += stores the same value as a plain assignment would.
340+ result_data[i] = 0 ;
341+
/// NOTE(review): presumably the division converts a seconds-based origin to days because these
/// units produce a Date result - confirm. The check looks only at DateTime, not DateTime64.
342+ if (isDateTime (time_data_type) && (unit == IntervalKind::Week || unit == IntervalKind::Month || unit == IntervalKind::Quarter || unit == IntervalKind::Year))
343+ {
344+ result_data[i] += res + (origin / SECONDS_PER_DAY);
345+ }
346+ else
347+ {
348+ result_data[i] += res + origin;
349+ }
350+ }
351+ }
256352
257353 return result_col;
258354 }
0 commit comments