Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions be/src/vec/functions/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "common/logging.h"
#include "common/status.h"
#include "olap/rowset/segment_v2/inverted_index_iterator.h" // IWYU pragma: keep
#include "runtime/define_primitive_type.h"
#include "udf/udf.h"
#include "vec/core/block.h"
#include "vec/core/column_numbers.h"
Expand Down Expand Up @@ -538,6 +539,135 @@ class DefaultFunctionBuilder : public FunctionBuilderImpl {
std::shared_ptr<IFunction> function;
};

// Helper struct to store information about const+nullable columns
struct ColumnWithConstAndNullMap {
const IColumn* nested_col = nullptr;
const NullMap* null_map = nullptr;
bool is_const = false;

bool is_null_at(size_t row) const { return (null_map && (*null_map)[is_const ? 0 : row]); }
};

// For functions that need to handle const+nullable column combinations
// means that functioin `use_default_implementation_for_nulls()` returns false
template <typename Impl, PrimitiveType ResultPrimitiveType>
class FunctionNeedsToHandleNull : public IFunction {
public:
using ResultColumnType = PrimitiveTypeTraits<ResultPrimitiveType>::ColumnType;

static constexpr auto name = Impl::name;
String get_name() const override { return name; }

static std::shared_ptr<IFunction> create() {
return std::make_shared<FunctionNeedsToHandleNull>();
}

size_t get_number_of_arguments() const override { return Impl::get_number_of_arguments(); }

bool is_variadic() const override {
if constexpr (requires { Impl::is_variadic(); }) {
return Impl::is_variadic();
}
return false;
}

bool use_default_implementation_for_nulls() const override { return false; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return Impl::get_return_type_impl(arguments);
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const override {
auto res_col = ResultColumnType::create();
auto null_map = ColumnUInt8::create();
auto& null_map_data = null_map->get_data();
res_col->reserve(input_rows_count);
null_map_data.resize_fill(input_rows_count, 0);

const size_t arg_size = arguments.size();

std::vector<ColumnsWithConstAndNullMap> columns_info;
columns_info.resize(arg_size);
bool has_nullable = false;
collect_columns_info(columns_info, block, arguments, has_nullable);

// Check if there is a const null
for (size_t i = 0; i < arg_size; ++i) {
if (columns_info[i].is_const && columns_info[i].null_map &&
(*columns_info[i].null_map)[0] &&
execute_const_null(res_col, null_map_data, input_rows_count, i)) {
block.replace_by_position(
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
return Status::OK();
}
}

Impl::execute(columns_info, res_col, null_map_data, input_rows_count);

if (is_return_nullable(has_nullable, columns_info)) {
block.replace_by_position(
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
} else {
block.replace_by_position(result, std::move(res_col));
}

return Status::OK();
}

private:
// Handle a NULL literal
// Default behavior is fill result with all NULLs
// return true when the res_col is ready to be written back to the block without further processing
bool execute_const_null(typename ResultColumnType::MutablePtr& res_col,
PaddedPODArray<UInt8>& res_null_map_data, size_t input_rows_count,
size_t null_index) const {
if constexpr (requires {
Impl::execute_const_null(res_col, res_null_map_data, input_rows_count,
null_index);
}) {
return Impl::execute_const_null(res_col, res_null_map_data, input_rows_count,
null_index);
}

res_col->insert_many_defaults(input_rows_count);
res_null_map_data.assign(input_rows_count, (UInt8)1);

return true;
}

// Collect the required information for each column into columns_info
// Including whether it is a constant column, nested column and null map(if exists).
void collect_columns_info(std::vector<ColumnsWithConstAndNullMap>& columns_info,
const Block& block, const ColumnNumbers& arguments,
bool& has_nullable) const {
for (size_t i = 0; i < arguments.size(); ++i) {
ColumnPtr col_ptr;
const auto& col_with_type = block.get_by_position(arguments[i]);
std::tie(col_ptr, columns_info[i].is_const) = unpack_if_const(col_with_type.column);

if (is_column_nullable(*col_ptr)) {
has_nullable = true;
const auto* nullable = check_and_get_column<ColumnNullable>(col_ptr.get());
columns_info[i].nested_col = &nullable->get_nested_column();
columns_info[i].null_map = &nullable->get_null_map_data();
} else {
columns_info[i].nested_col = col_ptr.get();
}
}
}

// Determine if the return type should be wrapped in nullable
// Default behavior is return nullable if any argument is nullable
bool is_return_nullable(bool has_nullable,
const std::vector<ColumnsWithConstAndNullMap>& cols_info) const {
if constexpr (requires { Impl::is_return_nullable(has_nullable, cols_info); }) {
return Impl::is_return_nullable(has_nullable, cols_info);
}
return has_nullable;
}
};

using FunctionPtr = std::shared_ptr<IFunction>;

/** Return ColumnNullable of src, with null map as OR-ed null maps of args columns in blocks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ using FunctionSecToTime = FunctionCurrentDateOrDateTime<SecToTimeImpl>;
using FunctionMicroSecToDateTime = TimestampToDateTime<MicroSec>;
using FunctionMilliSecToDateTime = TimestampToDateTime<MilliSec>;
using FunctionSecToDateTime = TimestampToDateTime<Sec>;
using FunctionPeriodAdd = FunctionNeedsToHandleNull<PeriodAddImpl, PrimitiveType::TYPE_BIGINT>;
using FunctionPeriodDiff = FunctionNeedsToHandleNull<PeriodDiffImpl, PrimitiveType::TYPE_BIGINT>;

void register_function_date_time_computation(SimpleFunctionFactory& factory) {
factory.register_function<FunctionDateDiff>();
Expand Down Expand Up @@ -94,6 +96,8 @@ void register_function_date_time_computation(SimpleFunctionFactory& factory) {
factory.register_function<FunctionMonthsBetween>();
factory.register_function<FunctionTime>();
factory.register_function<FunctionGetFormat>();
factory.register_function<FunctionPeriodAdd>();
factory.register_function<FunctionPeriodDiff>();

// alias
factory.register_alias("days_add", "date_add");
Expand Down
89 changes: 89 additions & 0 deletions be/src/vec/functions/function_date_or_datetime_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1378,5 +1378,94 @@ class FunctionGetFormat : public IFunction {
static constexpr auto TIME_NAME = "TIME";
};

class PeriodHelper {
public:
// For two digit year, 70-99 -> 1970-1999, 00-69 -> 2000-2069
// this rule is same as MySQL
static constexpr int YY_PART_YEAR = 70;
static Status valid_period(int64_t period) {
if (period <= 0 || (period % 100) == 0 || (period % 100) > 12) {
return Status::InvalidArgument("Period function got invalid period: {}", period);
}
return Status::OK();
}

static int64_t check_and_convert_period_to_month(uint64_t period) {
THROW_IF_ERROR(valid_period(period));
uint64_t year = period / 100;
if (year < 100) {
year += (year >= YY_PART_YEAR) ? 1900 : 2000;
}
return year * 12LL + (period % 100) - 1;
}

static int64_t convert_month_to_period(uint64_t month) {
uint64_t year = month / 12;
if (year < 100) {
year += (year >= YY_PART_YEAR) ? 1900 : 2000;
}
return year * 100 + month % 12 + 1;
}
};

class PeriodAddImpl {
public:
static constexpr auto name = "period_add";
static size_t get_number_of_arguments() { return 2; }
static DataTypePtr get_return_type_impl(const DataTypes& arguments) {
return std::make_shared<DataTypeInt64>();
}

static void execute(const std::vector<ColumnsWithConstAndNullMap>& cols_info,
ColumnInt64::MutablePtr& res_col, PaddedPODArray<UInt8>& res_null_map_data,
size_t input_rows_count) {
const auto& left_data =
assert_cast<const ColumnInt64*>(cols_info[0].nested_col)->get_data();
const auto& right_data =
assert_cast<const ColumnInt64*>(cols_info[1].nested_col)->get_data();
for (size_t i = 0; i < input_rows_count; ++i) {
if (cols_info[0].is_null_at(i) || cols_info[1].is_null_at(i)) {
res_col->insert_default();
res_null_map_data[i] = 1;
continue;
}

int64_t period = left_data[index_check_const(i, cols_info[0].is_const)];
int64_t months = right_data[index_check_const(i, cols_info[1].is_const)];
res_col->insert_value(PeriodHelper::convert_month_to_period(
PeriodHelper::check_and_convert_period_to_month(period) + months));
}
}
};
class PeriodDiffImpl {
public:
static constexpr auto name = "period_diff";
static size_t get_number_of_arguments() { return 2; }
static DataTypePtr get_return_type_impl(const DataTypes& arguments) {
return std::make_shared<DataTypeInt64>();
}

static void execute(const std::vector<ColumnsWithConstAndNullMap>& cols_info,
ColumnInt64::MutablePtr& res_col, PaddedPODArray<UInt8>& res_null_map_data,
size_t input_rows_count) {
const auto& left_data =
assert_cast<const ColumnInt64*>(cols_info[0].nested_col)->get_data();
const auto& right_data =
assert_cast<const ColumnInt64*>(cols_info[1].nested_col)->get_data();
for (size_t i = 0; i < input_rows_count; ++i) {
if (cols_info[0].is_null_at(i) || cols_info[1].is_null_at(i)) {
res_col->insert_default();
res_null_map_data[i] = 1;
continue;
}

int64_t period1 = left_data[index_check_const(i, cols_info[0].is_const)];
int64_t period2 = right_data[index_check_const(i, cols_info[1].is_const)];
res_col->insert_value(PeriodHelper::check_and_convert_period_to_month(period1) -
PeriodHelper::check_and_convert_period_to_month(period2));
}
}
};

#include "common/compile_check_avoid_end.h"
} // namespace doris::vectorized
2 changes: 2 additions & 0 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,8 @@ using FunctionStringAppendTrailingCharIfAbsent =
using FunctionStringLPad = FunctionStringPad<StringLPad>;
using FunctionStringRPad = FunctionStringPad<StringRPad>;

using FunctionMakeSet = FunctionNeedsToHandleNull<MakeSetImpl, PrimitiveType::TYPE_STRING>;

void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionStringParseDataSize>();
factory.register_function<FunctionStringASCII>();
Expand Down
Loading