Skip to content

Commit 5ca6a29

Browse files
committed
[Feature](function) Support function DEFAULT
1 parent 2490469 commit 5ca6a29

File tree

10 files changed

+663
-3
lines changed

10 files changed

+663
-3
lines changed

be/src/runtime/descriptors.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ SlotDescriptor::SlotDescriptor(const TSlotDescriptor& tdesc)
6464
_is_key(tdesc.is_key),
6565
_column_paths(tdesc.column_paths),
6666
_is_auto_increment(tdesc.__isset.is_auto_increment ? tdesc.is_auto_increment : false),
67-
_col_default_value(tdesc.__isset.col_default_value ? tdesc.col_default_value : "") {
67+
_col_default_value(tdesc.__isset.col_default_value ? tdesc.col_default_value : ""),
68+
_has_default_value(tdesc.__isset.col_default_value) {
6869
if (tdesc.__isset.virtual_column_expr) {
6970
// Make sure virtual column is valid.
7071
if (tdesc.virtual_column_expr.nodes.empty()) {
@@ -98,7 +99,9 @@ SlotDescriptor::SlotDescriptor(const PSlotDescriptor& pdesc)
9899
_is_materialized(pdesc.is_materialized()),
99100
_is_key(pdesc.is_key()),
100101
_column_paths(pdesc.column_paths().begin(), pdesc.column_paths().end()),
101-
_is_auto_increment(pdesc.is_auto_increment()) {}
102+
_is_auto_increment(pdesc.is_auto_increment()),
103+
_col_default_value(),
104+
_has_default_value(false) {}
102105

103106
#ifdef BE_TEST
104107
SlotDescriptor::SlotDescriptor()
@@ -111,7 +114,9 @@ SlotDescriptor::SlotDescriptor()
111114
_field_idx(-1),
112115
_is_materialized(true),
113116
_is_key(false),
114-
_is_auto_increment(false) {}
117+
_is_auto_increment(false),
118+
_col_default_value(),
119+
_has_default_value(false) {}
115120
#endif
116121

117122
void SlotDescriptor::to_protobuf(PSlotDescriptor* pslot) const {

be/src/runtime/descriptors.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class SlotDescriptor {
8888
bool is_sequence_col() const { return _col_name == SEQUENCE_COL; }
8989

9090
const std::string& col_default_value() const { return _col_default_value; }
91+
bool has_default_value() const { return _has_default_value; }
9192
PrimitiveType col_type() const;
9293

9394
std::shared_ptr<doris::TExpr> get_virtual_column_expr() const {
@@ -131,6 +132,11 @@ class SlotDescriptor {
131132
const bool _is_auto_increment;
132133
const std::string _col_default_value;
133134

135+
// When the default value is NULL, _col_default_value will be initialized to an empty string.
136+
// This flag is used to distinguish whether the empty string
137+
// of `_col_default_value` is the default value or NULL.
138+
const bool _has_default_value;
139+
134140
std::shared_ptr<doris::TExpr> virtual_column_expr = nullptr;
135141

136142
SlotDescriptor(const TSlotDescriptor& tdesc);
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <mysql/binary_log_types.h>
19+
20+
#include <string>
21+
22+
#include "common/status.h"
23+
#include "runtime/define_primitive_type.h"
24+
#include "runtime/descriptors.h"
25+
#include "runtime/primitive_type.h"
26+
#include "runtime/runtime_state.h"
27+
#include "util/binary_cast.hpp"
28+
#include "vec/columns/column_const.h"
29+
#include "vec/columns/column_nullable.h"
30+
#include "vec/core/column_with_type_and_name.h"
31+
#include "vec/data_types/data_type_nullable.h"
32+
#include "vec/data_types/serde/data_type_serde.h"
33+
#include "vec/functions/function.h"
34+
#include "vec/functions/simple_function_factory.h"
35+
#include "vec/runtime/vdatetime_value.h"
36+
37+
namespace doris::vectorized {
38+
#include "common/compile_check_begin.h"
39+
40+
class FunctionDefault : public IFunction {
41+
public:
42+
static constexpr auto name = "default";
43+
static FunctionPtr create() { return std::make_shared<FunctionDefault>(); }
44+
String get_name() const override { return name; }
45+
size_t get_number_of_arguments() const override { return 1; }
46+
47+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
48+
return make_nullable(arguments[0]);
49+
}
50+
51+
bool use_default_implementation_for_nulls() const override { return false; }
52+
53+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
54+
uint32_t result, size_t input_rows_count) const override {
55+
ColumnWithTypeAndName& result_info = block.get_by_position(result);
56+
auto res_nested_type = remove_nullable(result_info.type);
57+
PrimitiveType res_primitive_type = res_nested_type->get_primitive_type();
58+
59+
ColumnWithTypeAndName& input_column_info = block.get_by_position(arguments[0]);
60+
const std::string& col_name = input_column_info.name;
61+
62+
std::string default_value;
63+
bool has_default_value = false;
64+
bool is_nullable = true;
65+
get_default_value_and_nullable_for_col(context, col_name, default_value, has_default_value,
66+
is_nullable);
67+
68+
// For date types, if the default value is `CURRENT_TIMESTAMP` or `CURRENT_DATE`.
69+
// Reference the behavior in MySQL:
70+
// if column is NULLABLE, return all NULLs
71+
// else return zero datetime values like `0000-00-00 00:00:00` or `0000-00-00`
72+
if (is_date_type(res_primitive_type) && has_default_value &&
73+
(default_value == "CURRENT_TIMESTAMP" || default_value == "CURRENT_DATE")) {
74+
if (is_nullable) {
75+
return return_with_all_null(block, result, res_nested_type, input_rows_count);
76+
} else {
77+
return return_with_zero_datetime(block, result, res_nested_type, res_primitive_type,
78+
input_rows_count);
79+
}
80+
}
81+
82+
// For some complex types, only accept NULL as default value
83+
if (is_complex_type(res_primitive_type) || res_primitive_type == TYPE_JSONB ||
84+
res_primitive_type == TYPE_VARIANT) {
85+
if (is_nullable) {
86+
return return_with_all_null(block, result, res_nested_type, input_rows_count);
87+
} else {
88+
return Status::InvalidArgument(
89+
"Column '{}' of type '{}' must be nullable to use DEFAULT", col_name,
90+
res_nested_type->get_name());
91+
}
92+
}
93+
94+
// 1. specified default value when creating table -> default_value
95+
// 2. no specified default value && column is NOT NULL -> error
96+
// 3. no specified default value && column is NULLABLE -> NULL
97+
if (has_default_value) {
98+
MutableColumnPtr res_col = res_nested_type->create_column();
99+
auto null_map = ColumnUInt8::create(input_rows_count, 0);
100+
Field default_field;
101+
102+
auto temp_column = res_nested_type->create_column();
103+
auto serde = res_nested_type->get_serde();
104+
StringRef default_str_ref(default_value.data(), default_value.size());
105+
DataTypeSerDe::FormatOptions options;
106+
Status parse_status = serde->from_string(default_str_ref, *temp_column, options);
107+
108+
if (parse_status.ok() && temp_column->size() > 0) {
109+
temp_column->get(0, default_field);
110+
}
111+
112+
for (size_t i = 0; i < input_rows_count; ++i) {
113+
res_col->insert(default_field);
114+
}
115+
block.replace_by_position(
116+
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
117+
return Status::OK();
118+
} else {
119+
if (is_nullable) {
120+
return return_with_all_null(block, result, res_nested_type, input_rows_count);
121+
} else {
122+
return Status::InvalidArgument("Column '{}' is NOT NULL but has no default value",
123+
col_name);
124+
}
125+
}
126+
return Status::OK();
127+
}
128+
129+
private:
130+
void get_default_value_and_nullable_for_col(FunctionContext* context,
131+
const std::string& column_name,
132+
std::string& default_value, bool& has_default_value,
133+
bool& is_nullable) const {
134+
RuntimeState* state = context->state();
135+
const DescriptorTbl& desc_tbl = state->desc_tbl();
136+
137+
SlotDescriptor* target_slot = nullptr;
138+
for (auto* tuple_desc : desc_tbl.get_tuple_descs()) {
139+
for (auto* slot : tuple_desc->slots()) {
140+
if (slot->col_name() == column_name) {
141+
target_slot = slot;
142+
break;
143+
}
144+
}
145+
if (target_slot) {
146+
break;
147+
}
148+
}
149+
150+
if (target_slot) {
151+
is_nullable = target_slot->is_nullable();
152+
default_value = target_slot->col_default_value();
153+
has_default_value = target_slot->has_default_value();
154+
}
155+
}
156+
157+
static Status return_with_all_null(Block& block, uint32_t result,
158+
const DataTypePtr& nested_type, size_t input_rows_count) {
159+
MutableColumnPtr res_col = nested_type->create_column();
160+
res_col->insert_many_defaults(input_rows_count);
161+
auto null_map = ColumnUInt8::create(input_rows_count, 1);
162+
block.replace_by_position(result,
163+
ColumnNullable::create(std::move(res_col), std::move(null_map)));
164+
return Status::OK();
165+
}
166+
167+
static Status return_with_zero_datetime(Block& block, uint32_t result,
168+
const DataTypePtr& nested_type,
169+
PrimitiveType primitive_type, size_t input_rows_count) {
170+
MutableColumnPtr res_col = nested_type->create_column();
171+
172+
switch (primitive_type) {
173+
case TYPE_DATE:
174+
case TYPE_DATETIME:
175+
insert_min_datetime_values<TYPE_DATETIME>(res_col, input_rows_count);
176+
break;
177+
case TYPE_DATEV2:
178+
insert_min_datetime_values<TYPE_DATEV2>(res_col, input_rows_count);
179+
break;
180+
case TYPE_DATETIMEV2:
181+
insert_min_datetime_values<TYPE_DATETIMEV2>(res_col, input_rows_count);
182+
break;
183+
default:
184+
return Status::InternalError("Unsupported date/time type for zero datetime: {}",
185+
nested_type->get_name());
186+
}
187+
188+
auto null_map = ColumnUInt8::create(input_rows_count, 0);
189+
block.replace_by_position(result,
190+
ColumnNullable::create(std::move(res_col), std::move(null_map)));
191+
return Status::OK();
192+
}
193+
194+
template <PrimitiveType Type>
195+
static void insert_min_datetime_values(MutableColumnPtr& res_col, size_t count) {
196+
using ItemType = typename PrimitiveTypeTraits<Type>::ColumnItemType;
197+
ItemType min_value;
198+
199+
if constexpr (Type == TYPE_DATE || Type == TYPE_DATETIME) {
200+
min_value =
201+
binary_cast<VecDateTimeValue, ItemType>(VecDateTimeValue::datetime_min_value());
202+
} else if constexpr (Type == TYPE_DATEV2) {
203+
min_value = MIN_DATE_V2;
204+
} else if constexpr (Type == TYPE_DATETIMEV2) {
205+
min_value = MIN_DATETIME_V2;
206+
}
207+
208+
for (size_t i = 0; i < count; ++i) {
209+
res_col->insert_data(reinterpret_cast<const char*>(&min_value), sizeof(ItemType));
210+
}
211+
}
212+
};
213+
#include "common/compile_check_end.h"
214+
215+
/// Makes the `default` scalar function (FunctionDefault) available through `factory`.
void register_function_default(SimpleFunctionFactory& factory) {
    factory.register_function<FunctionDefault>();
}
218+
} // namespace doris::vectorized

be/src/vec/functions/simple_function_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ void register_function_ignore(SimpleFunctionFactory& factory);
106106
void register_function_encryption(SimpleFunctionFactory& factory);
107107
void register_function_regexp_extract(SimpleFunctionFactory& factory);
108108
void register_function_hex_variadic(SimpleFunctionFactory& factory);
109+
void register_function_default(SimpleFunctionFactory& factory);
109110
void register_function_match(SimpleFunctionFactory& factory);
110111
void register_function_tokenize(SimpleFunctionFactory& factory);
111112
void register_function_url(SimpleFunctionFactory& factory);
@@ -315,6 +316,7 @@ class SimpleFunctionFactory {
315316
register_function_convert_tz(instance);
316317
register_function_least_greast(instance);
317318
register_function_fake(instance);
319+
register_function_default(instance);
318320
register_function_encryption(instance);
319321
register_function_regexp_extract(instance);
320322
register_function_hex_variadic(instance);

fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,6 +1974,7 @@ nonReserved
19741974
| DECIMAL
19751975
| DECIMALV2
19761976
| DECIMALV3
1977+
| DEFAULT
19771978
| DEFERRED
19781979
| DEMAND
19791980
| DIAGNOSE

fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysSub;
176176
import org.apache.doris.nereids.trees.expressions.functions.scalar.Dceil;
177177
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
178+
import org.apache.doris.nereids.trees.expressions.functions.scalar.Default;
178179
import org.apache.doris.nereids.trees.expressions.functions.scalar.Degrees;
179180
import org.apache.doris.nereids.trees.expressions.functions.scalar.Dexp;
180181
import org.apache.doris.nereids.trees.expressions.functions.scalar.Dfloor;
@@ -705,6 +706,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
705706
scalar(DaysSub.class, "days_sub", "date_sub", "subdate"),
706707
scalar(Dceil.class, "dceil"),
707708
scalar(DecodeAsVarchar.class, "decode_as_varchar"),
709+
scalar(Default.class, "default"),
708710
scalar(Degrees.class, "degrees"),
709711
scalar(Dexp.class, "dexp"),
710712
scalar(Dfloor.class, "dfloor"),
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.scalar;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.nereids.exceptions.AnalysisException;
22+
import org.apache.doris.nereids.trees.expressions.Expression;
23+
import org.apache.doris.nereids.trees.expressions.SlotReference;
24+
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
25+
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
26+
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
27+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
28+
import org.apache.doris.nereids.types.DataType;
29+
30+
import com.google.common.base.Preconditions;
31+
32+
import java.util.List;
33+
34+
/**
35+
* ScalarFunction 'default'. This function returns the default value of a column.
36+
*/
37+
public class Default extends ScalarFunction
38+
implements UnaryExpression, CustomSignature, AlwaysNullable {
39+
40+
/**
41+
* constructor with 1 argument.
42+
*/
43+
public Default(Expression arg) {
44+
super("default", arg);
45+
}
46+
47+
/** constructor for withChildren and reuse signature */
48+
private Default(ScalarFunctionParams functionParams) {
49+
super(functionParams);
50+
}
51+
52+
/**
53+
* withChildren.
54+
*/
55+
@Override
56+
public Default withChildren(List<Expression> children) {
57+
Preconditions.checkArgument(children.size() == 1);
58+
return new Default(getFunctionParams(children));
59+
}
60+
61+
@Override
62+
public FunctionSignature customSignature() {
63+
// Return signature that accepts any type and returns the same type (but nullable)
64+
DataType argType = getArgumentType(0);
65+
return FunctionSignature.ret(argType).args(argType);
66+
}
67+
68+
@Override
69+
public void checkLegalityAfterRewrite() {
70+
Expression arg = getArgument(0);
71+
if (!(arg instanceof SlotReference)) {
72+
throw new AnalysisException("DEFAULT function requires a column reference, not a constant or expression");
73+
}
74+
}
75+
76+
@Override
77+
public void checkLegalityBeforeTypeCoercion() {
78+
checkLegalityAfterRewrite();
79+
}
80+
81+
@Override
82+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
83+
return visitor.visitDefault(this, context);
84+
}
85+
}

0 commit comments

Comments
 (0)