Skip to content

Commit 85c6f70

Browse files
committed
[Feature](func) Support function PERIOD_ADD and PERIOD_DIFF (apache#56945)
```text
mysql> SELECT PERIOD_ADD(2512, 1);
+---------------------+
| PERIOD_ADD(2512, 1) |
+---------------------+
|              202601 |
+---------------------+

mysql> SELECT PERIOD_ADD(6901, 1);
+---------------------+
| PERIOD_ADD(6901, 1) |
+---------------------+
|              206902 |
+---------------------+

mysql> SELECT PERIOD_ADD(7001, 1);
+---------------------+
| PERIOD_ADD(7001, 1) |
+---------------------+
|              197002 |
+---------------------+

mysql> SELECT PERIOD_DIFF(2510, 2501);
+-------------------------+
| PERIOD_DIFF(2510, 2501) |
+-------------------------+
|                       9 |
+-------------------------+

mysql> SELECT PERIOD_DIFF(2501, 2510);
+-------------------------+
| PERIOD_DIFF(2501, 2510) |
+-------------------------+
|                      -9 |
+-------------------------+
```
1 parent f124c70 commit 85c6f70

File tree

13 files changed

+590
-82
lines changed

13 files changed

+590
-82
lines changed

be/src/vec/functions/function.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "common/logging.h"
3333
#include "common/status.h"
3434
#include "olap/rowset/segment_v2/inverted_index_iterator.h" // IWYU pragma: keep
35+
#include "runtime/define_primitive_type.h"
3536
#include "udf/udf.h"
3637
#include "vec/core/block.h"
3738
#include "vec/core/column_numbers.h"

be/src/vec/functions/function_date_or_datetime_computation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ using FunctionSecToTime = FunctionCurrentDateOrDateTime<SecToTimeImpl>;
6767
using FunctionMicroSecToDateTime = TimestampToDateTime<MicroSec>;
6868
using FunctionMilliSecToDateTime = TimestampToDateTime<MilliSec>;
6969
using FunctionSecToDateTime = TimestampToDateTime<Sec>;
70+
using FunctionPeriodAdd = FunctionNeedsToHandleNull<PeriodAddImpl, PrimitiveType::TYPE_BIGINT>;
71+
using FunctionPeriodDiff = FunctionNeedsToHandleNull<PeriodDiffImpl, PrimitiveType::TYPE_BIGINT>;
7072

7173
void register_function_date_time_computation(SimpleFunctionFactory& factory) {
7274
factory.register_function<FunctionDateDiff>();
@@ -99,6 +101,8 @@ void register_function_date_time_computation(SimpleFunctionFactory& factory) {
99101
factory.register_function<FunctionMonthsBetween>();
100102
factory.register_function<FunctionTime>();
101103
factory.register_function<FunctionGetFormat>();
104+
factory.register_function<FunctionPeriodAdd>();
105+
factory.register_function<FunctionPeriodDiff>();
102106

103107
// alias
104108
factory.register_alias("days_add", "date_add");

be/src/vec/functions/function_date_or_datetime_computation.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
#include "vec/functions/datetime_errors.h"
6060
#include "vec/functions/function.h"
6161
#include "vec/functions/function_helpers.h"
62+
#include "vec/functions/function_needs_to_handle_null.h"
6263
#include "vec/runtime/time_value.h"
6364
#include "vec/runtime/vdatetime_value.h"
6465
#include "vec/utils/util.hpp"
@@ -1529,5 +1530,94 @@ class FunctionGetFormat : public IFunction {
15291530
static constexpr auto TIME_NAME = "TIME";
15301531
};
15311532

1533+
class PeriodHelper {
1534+
public:
1535+
// For two digit year, 70-99 -> 1970-1999, 00-69 -> 2000-2069
1536+
// this rule is same as MySQL
1537+
static constexpr int YY_PART_YEAR = 70;
1538+
static Status valid_period(int64_t period) {
1539+
if (period <= 0 || (period % 100) == 0 || (period % 100) > 12) {
1540+
return Status::InvalidArgument("Period function got invalid period: {}", period);
1541+
}
1542+
return Status::OK();
1543+
}
1544+
1545+
static int64_t check_and_convert_period_to_month(uint64_t period) {
1546+
THROW_IF_ERROR(valid_period(period));
1547+
uint64_t year = period / 100;
1548+
if (year < 100) {
1549+
year += (year >= YY_PART_YEAR) ? 1900 : 2000;
1550+
}
1551+
return year * 12LL + (period % 100) - 1;
1552+
}
1553+
1554+
static int64_t convert_month_to_period(uint64_t month) {
1555+
uint64_t year = month / 12;
1556+
if (year < 100) {
1557+
year += (year >= YY_PART_YEAR) ? 1900 : 2000;
1558+
}
1559+
return year * 100 + month % 12 + 1;
1560+
}
1561+
};
1562+
1563+
class PeriodAddImpl {
1564+
public:
1565+
static constexpr auto name = "period_add";
1566+
static size_t get_number_of_arguments() { return 2; }
1567+
static DataTypePtr get_return_type_impl(const DataTypes& arguments) {
1568+
return std::make_shared<DataTypeInt64>();
1569+
}
1570+
1571+
static void execute(const std::vector<ColumnWithConstAndNullMap>& cols_info,
1572+
ColumnInt64::MutablePtr& res_col, PaddedPODArray<UInt8>& res_null_map_data,
1573+
size_t input_rows_count) {
1574+
const auto& left_data =
1575+
assert_cast<const ColumnInt64*>(cols_info[0].nested_col)->get_data();
1576+
const auto& right_data =
1577+
assert_cast<const ColumnInt64*>(cols_info[1].nested_col)->get_data();
1578+
for (size_t i = 0; i < input_rows_count; ++i) {
1579+
if (cols_info[0].is_null_at(i) || cols_info[1].is_null_at(i)) {
1580+
res_col->insert_default();
1581+
res_null_map_data[i] = 1;
1582+
continue;
1583+
}
1584+
1585+
int64_t period = left_data[index_check_const(i, cols_info[0].is_const)];
1586+
int64_t months = right_data[index_check_const(i, cols_info[1].is_const)];
1587+
res_col->insert_value(PeriodHelper::convert_month_to_period(
1588+
PeriodHelper::check_and_convert_period_to_month(period) + months));
1589+
}
1590+
}
1591+
};
1592+
class PeriodDiffImpl {
1593+
public:
1594+
static constexpr auto name = "period_diff";
1595+
static size_t get_number_of_arguments() { return 2; }
1596+
static DataTypePtr get_return_type_impl(const DataTypes& arguments) {
1597+
return std::make_shared<DataTypeInt64>();
1598+
}
1599+
1600+
static void execute(const std::vector<ColumnWithConstAndNullMap>& cols_info,
1601+
ColumnInt64::MutablePtr& res_col, PaddedPODArray<UInt8>& res_null_map_data,
1602+
size_t input_rows_count) {
1603+
const auto& left_data =
1604+
assert_cast<const ColumnInt64*>(cols_info[0].nested_col)->get_data();
1605+
const auto& right_data =
1606+
assert_cast<const ColumnInt64*>(cols_info[1].nested_col)->get_data();
1607+
for (size_t i = 0; i < input_rows_count; ++i) {
1608+
if (cols_info[0].is_null_at(i) || cols_info[1].is_null_at(i)) {
1609+
res_col->insert_default();
1610+
res_null_map_data[i] = 1;
1611+
continue;
1612+
}
1613+
1614+
int64_t period1 = left_data[index_check_const(i, cols_info[0].is_const)];
1615+
int64_t period2 = right_data[index_check_const(i, cols_info[1].is_const)];
1616+
res_col->insert_value(PeriodHelper::check_and_convert_period_to_month(period1) -
1617+
PeriodHelper::check_and_convert_period_to_month(period2));
1618+
}
1619+
}
1620+
};
1621+
15321622
#include "common/compile_check_avoid_end.h"
15331623
} // namespace doris::vectorized
be/src/vec/functions/function_needs_to_handle_null.h

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#pragma once
18+
#include <boost/mpl/aux_/na_fwd.hpp>
19+
20+
#include "vec/functions/function.h"
21+
22+
namespace doris::vectorized {
23+
#include "common/compile_check_begin.h"
24+
25+
// Helper struct to store information about const+nullable columns
26+
struct ColumnWithConstAndNullMap {
27+
const IColumn* nested_col = nullptr;
28+
const NullMap* null_map = nullptr;
29+
bool is_const = false;
30+
31+
bool is_null_at(size_t row) const { return (null_map && (*null_map)[is_const ? 0 : row]); }
32+
};
33+
34+
// For functions that need to handle const+nullable column combinations
35+
// means that functioin `use_default_implementation_for_nulls()` returns false
36+
template <typename Impl, PrimitiveType ResultPrimitiveType>
37+
class FunctionNeedsToHandleNull : public IFunction {
38+
public:
39+
using ResultColumnType = PrimitiveTypeTraits<ResultPrimitiveType>::ColumnType;
40+
41+
static constexpr auto name = Impl::name;
42+
String get_name() const override { return name; }
43+
44+
static std::shared_ptr<IFunction> create() {
45+
return std::make_shared<FunctionNeedsToHandleNull>();
46+
}
47+
48+
size_t get_number_of_arguments() const override { return Impl::get_number_of_arguments(); }
49+
50+
bool is_variadic() const override {
51+
if constexpr (requires { Impl::is_variadic(); }) {
52+
return Impl::is_variadic();
53+
}
54+
return false;
55+
}
56+
57+
bool use_default_implementation_for_nulls() const override { return false; }
58+
59+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
60+
return Impl::get_return_type_impl(arguments);
61+
}
62+
63+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
64+
uint32_t result, size_t input_rows_count) const override {
65+
auto res_col = ResultColumnType::create();
66+
auto null_map = ColumnUInt8::create();
67+
auto& null_map_data = null_map->get_data();
68+
res_col->reserve(input_rows_count);
69+
null_map_data.resize_fill(input_rows_count, 0);
70+
71+
const size_t arg_size = arguments.size();
72+
73+
std::vector<ColumnWithConstAndNullMap> columns_info;
74+
columns_info.resize(arg_size);
75+
bool has_nullable = false;
76+
collect_columns_info(columns_info, block, arguments, has_nullable);
77+
78+
// Check if there is a const null
79+
for (size_t i = 0; i < arg_size; ++i) {
80+
if (columns_info[i].is_const && columns_info[i].null_map &&
81+
(*columns_info[i].null_map)[0] &&
82+
execute_const_null(res_col, null_map_data, input_rows_count, i)) {
83+
block.replace_by_position(
84+
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
85+
return Status::OK();
86+
}
87+
}
88+
89+
Impl::execute(columns_info, res_col, null_map_data, input_rows_count);
90+
91+
if (is_return_nullable(has_nullable, columns_info)) {
92+
block.replace_by_position(
93+
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
94+
} else {
95+
block.replace_by_position(result, std::move(res_col));
96+
}
97+
98+
return Status::OK();
99+
}
100+
101+
private:
102+
// Handle a NULL literal
103+
// Default behavior is fill result with all NULLs
104+
// return true when the res_col is ready to be written back to the block without further processing
105+
bool execute_const_null(typename ResultColumnType::MutablePtr& res_col,
106+
PaddedPODArray<UInt8>& res_null_map_data, size_t input_rows_count,
107+
size_t null_index) const {
108+
if constexpr (requires {
109+
Impl::execute_const_null(res_col, res_null_map_data, input_rows_count,
110+
null_index);
111+
}) {
112+
return Impl::execute_const_null(res_col, res_null_map_data, input_rows_count,
113+
null_index);
114+
}
115+
116+
res_col->insert_many_defaults(input_rows_count);
117+
res_null_map_data.assign(input_rows_count, (UInt8)1);
118+
119+
return true;
120+
}
121+
122+
// Collect the required information for each column into columns_info
123+
// Including whether it is a constant column, nested column and null map(if exists).
124+
void collect_columns_info(std::vector<ColumnWithConstAndNullMap>& columns_info,
125+
const Block& block, const ColumnNumbers& arguments,
126+
bool& has_nullable) const {
127+
for (size_t i = 0; i < arguments.size(); ++i) {
128+
ColumnPtr col_ptr;
129+
const auto& col_with_type = block.get_by_position(arguments[i]);
130+
std::tie(col_ptr, columns_info[i].is_const) = unpack_if_const(col_with_type.column);
131+
132+
if (is_column_nullable(*col_ptr)) {
133+
has_nullable = true;
134+
const auto* nullable = check_and_get_column<ColumnNullable>(col_ptr.get());
135+
columns_info[i].nested_col = &nullable->get_nested_column();
136+
columns_info[i].null_map = &nullable->get_null_map_data();
137+
} else {
138+
columns_info[i].nested_col = col_ptr.get();
139+
}
140+
}
141+
}
142+
143+
// Determine if the return type should be wrapped in nullable
144+
// Default behavior is return nullable if any argument is nullable
145+
bool is_return_nullable(bool has_nullable,
146+
const std::vector<ColumnWithConstAndNullMap>& cols_info) const {
147+
if constexpr (requires { Impl::is_return_nullable(has_nullable, cols_info); }) {
148+
return Impl::is_return_nullable(has_nullable, cols_info);
149+
}
150+
return has_nullable;
151+
}
152+
};
153+
#include "common/compile_check_end.h"
154+
} // namespace doris::vectorized

be/src/vec/functions/function_string.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,8 @@ using FunctionStringAppendTrailingCharIfAbsent =
13341334
using FunctionStringLPad = FunctionStringPad<StringLPad>;
13351335
using FunctionStringRPad = FunctionStringPad<StringRPad>;
13361336

1337+
using FunctionMakeSet = FunctionNeedsToHandleNull<MakeSetImpl, PrimitiveType::TYPE_STRING>;
1338+
13371339
void register_function_string(SimpleFunctionFactory& factory) {
13381340
factory.register_function<FunctionStringParseDataSize>();
13391341
factory.register_function<FunctionStringASCII>();

0 commit comments

Comments
 (0)