 // Licensed under the MIT License.

 #include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas.h"
 #include "core/util/math_cpuonly.h"
 #include "core/providers/common.h"
 #include "core/platform/threadpool.h"
@@ -36,52 +37,188 @@ REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(double)
 REGISTER_KERNEL_TYPED(MLFloat16)

-// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
-template <typename T, typename Ret>
-ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
-
-template <>
-ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
-  return val.ToFloat();
-}
-
-template <>
-ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
-  return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
+namespace {
+
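+// ComputeJob for float/double inputs: the skip (+ optional bias) addition and layer normalization
+// run directly in T; the cached float buffer parameters are only used by the MLFloat16 overload below.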
+template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
+void ComputeJob(
+    const T* input_data,
+    const T* skip_data,
+    const T* gamma_data,
+    const T* beta_data,
+    const T* bias_data,
+    IAllocatorUniquePtr<float>& skip_float_uptr,
+    IAllocatorUniquePtr<float>& gamma_float_uptr,
+    IAllocatorUniquePtr<float>& beta_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    ptrdiff_t task_idx,
+    int hidden_size,
+    int64_t skip_size,
+    float epsilon,
+    bool simplified,
+    T* output_data,
+    T* skip_input_bias_add_output_data,
+    AllocatorPtr alloc) {
+  ORT_UNUSED_PARAMETER(skip_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(gamma_float_uptr);  // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(beta_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(bias_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(alloc);
+
+  auto offset = task_idx * hidden_size;
+  const T* p_input = input_data + offset;
+  const T* p_skip = skip_data + (offset % skip_size);
+  T* p_output = output_data + offset;
+  T* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
+
+  T mean(0.0f);
+  T mean_square(0.0f);
+
+  for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
+    T val = p_input[h] + p_skip[h];
+
+    if (nullptr != bias_data) {
+      val += bias_data[h];
+    }
+
+    if (nullptr != p_skip_input_bias_add_output) {
+      p_skip_input_bias_add_output[h] = val;
+    }
+
+    p_output[h] = val;
+    mean += val;
+    mean_square += val * val;
+  }
+
+  mean = mean / hidden_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / hidden_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
+  }
+
+  for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
+    if (simplified) {
+      p_output[h] = p_output[h] / mean_square * gamma_data[h];
+    } else if (nullptr == beta_data) {
+      p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
+    } else {
+      p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
+    }
+  }
 }

-template <>
-ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
-  return val;
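+// ComputeJob for MLFloat16 inputs: values are converted to float with MLAS, the normalization is
+// done in float, and the result is converted back to half precision before it is written out.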
+void ComputeJob(
+    const MLFloat16* input_data,
+    const MLFloat16* skip_data,
+    const MLFloat16* gamma_data,
+    const MLFloat16* beta_data,
+    const MLFloat16* bias_data,
+    IAllocatorUniquePtr<float>& skip_float_uptr,
+    IAllocatorUniquePtr<float>& gamma_float_uptr,
+    IAllocatorUniquePtr<float>& beta_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    ptrdiff_t task_idx,
+    int hidden_size,
+    int64_t skip_size,
+    float epsilon,
+    bool simplified,
+    MLFloat16* output_data,
+    MLFloat16* skip_input_bias_add_output_data,
+    AllocatorPtr alloc) {
+  auto offset = task_idx * hidden_size;
+  const MLFloat16* p_input = input_data + offset;
+  const MLFloat16* p_skip = skip_data + (offset % skip_size);
+  MLFloat16* p_output = output_data + offset;
+  MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
+
+  float mean(0.0f);
+  float mean_square(0.0f);
+  const size_t num_elems = static_cast<size_t>(hidden_size);
+
+  IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);
+
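+  // skip, bias, gamma and beta are converted to float lazily on first use; the float buffers are
+  // cached in the kernel so later tasks can reuse them.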
+  if (!skip_float_uptr) {
+    skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
+  }
+
+  if (bias_data && !bias_float_uptr) {
+    bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
+  }
+
+  IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  float* output_float_ptr = output_float_uptr.get();
+
+  const float* input_float_ptr = input_float_uptr.get();
+  const float* skip_float_ptr = skip_float_uptr.get();
+  const float* bias_float_ptr = bias_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    float val = input_float_ptr[h] + skip_float_ptr[h];
+
+    if (bias_float_uptr) {
+      val += bias_float_ptr[h];
+    }
+
+    output_float_ptr[h] = val;
+    mean += val;
+    mean_square += val * val;
+  }
+
+  if (nullptr != p_skip_input_bias_add_output) {
+    MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems);
+  }
+
+  mean = mean / hidden_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / hidden_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
+  }
+
+  if (!gamma_float_uptr) {
+    gamma_float_uptr = std::move(input_float_uptr);  // overwrite input with gamma values, since they have the same size
+    MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems);
+  }
+
+  if (beta_data && !beta_float_uptr) {
+    beta_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems);
+  }
+
+  const float* gamma_float_ptr = gamma_float_uptr.get();
+  const float* beta_float_ptr = beta_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    if (simplified) {
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
+    } else if (nullptr == beta_float_uptr) {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
+    } else {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
+    }
+  }
+
+  MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
 }

-template <>
-ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
-  return val;
-}
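+// Converts an MLFloat16 initializer to a cached float buffer; is_packed is set so the kernel uses
+// the cached copy instead of the original half-precision tensor.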
+void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
+  if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
+    auto tensor_data_ptr = tensor.Data<MLFloat16>();
+    auto tensor_size = static_cast<size_t>(tensor.Shape().Size());
+    auto float_ptr = IAllocator::MakeUniquePtr<float>(alloc, tensor_size, true);

-// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
-  return val;
+    MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size);
+    dest = std::move(float_ptr);
+    is_packed = true;
+  }
 }

-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
-  return MLFloat16(val);
-}
-
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
-  return MLFloat16(static_cast<float>(val));
-}
+}  // namespace

 template <typename T, bool simplified>
 SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
-    : OpKernel(op_kernel_info) {
+    : OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
   ORT_ENFORCE(epsilon_ >= 0);
 }
@@ -94,8 +231,7 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   const Tensor* beta = p_ctx->Input<Tensor>(3);
   const Tensor* bias = p_ctx->Input<Tensor>(4);
   Tensor* output = p_ctx->Output(0, input->Shape());
-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
+  // For inferencing, we support one more optional output which is the sum of the input and skip tensors
   Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());

   const auto& input_dims = input->Shape().GetDims();
@@ -120,75 +256,44 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

   T* output_data = output->MutableData<T>();

-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
-  T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;
+  // For inferencing, we support one more optional output which is the sum of the input and skip tensors
+  T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

-  const auto& skip_size = skip->Shape().Size();
+  const int64_t& skip_size = skip->Shape().Size();
+
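+  // Temp-space allocator; the MLFloat16 path uses it to allocate the float conversion buffers.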
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

   concurrency::ThreadPool::TryBatchParallelFor(
       p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
       [&](ptrdiff_t task_idx) {
-        auto offset = task_idx * hidden_size;
-
-        const T* p_input = input_data + offset;
-        const T* p_skip = skip_data + (offset % skip_size);
-        T* p_output = output_data + offset;
-        T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;
-
-        using DoubleOrFloat = typename std::conditional<
-            std::is_same<T, double>::value,  // If T is double
-            double,                          // Use double
-            float                            // Otherwise, use float (covers float and MLFloat16)
-            >::type;
-
-        DoubleOrFloat mean(0.0f);
-        DoubleOrFloat mean_square(0.0f);
-
-        std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
-        for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-          DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
-          DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);
-
-          DoubleOrFloat value = input_value + skip_value;
-
-          if (nullptr != bias_data) {
-            value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
-          }
-
-          output_buffer[h] = value;
-          T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
-          if (nullptr != p_skip_input_bias_add_output_data) {
-            p_skip_input_bias_add_output_data[h] = converted_value;
-          }
-
-          mean += value;
-          mean_square += value * value;
-        }
-
-        mean = mean / hidden_size;
-        if (simplified) {
-          mean_square = sqrt(mean_square / hidden_size + epsilon_);
-        } else {
-          mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
-        }
-
-        for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-          DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
-          if (simplified) {
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
-          } else if (nullptr == beta_data) {
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
-          } else {
-            DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
-          }
-        }
+        ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
+                   bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
+                   skip_input_bias_add_output_data, alloc);
       },
       0);

   return Status::OK();
 }

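+// PrePack converts any half-precision initializer among skip, gamma, beta and bias to a cached
+// float buffer so the conversion is not repeated on every Compute call.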
+template <typename T, bool simplified>
+Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                                             bool& is_packed, PrePackedWeights* prepacked_weights) {
+  ORT_UNUSED_PARAMETER(prepacked_weights);
+
+  is_packed = false;
+  if (input_idx == 1) {  // skip
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
+  } else if (input_idx == 2) {  // gamma
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
+  } else if (input_idx == 3) {  // beta
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
+  } else if (input_idx == 4) {  // bias
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+  }
+
+  return Status::OK();
+}
+
 }  // namespace contrib
 }  // namespace onnxruntime