Skip to content

Commit 1499889

Browse files
Add correlation distance algorithm (#3131)
* Adds correlation distance algorithm, examples, docs, and unit tests Signed-off-by: North Iii <[email protected]> Co-authored-by: Victoriya Fedotova <[email protected]>
1 parent 97971aa commit 1499889

35 files changed

+2574
-6
lines changed

cpp/daal/src/algorithms/cordistance/cordistance_batch_impl.i

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,25 @@ template <typename algorithmFPType, Method method, CpuType cpu>
5656
services::Status DistanceKernel<algorithmFPType, method, cpu>::compute(const size_t na, const NumericTable * const * a, const size_t nr,
5757
NumericTable * r[], const daal::algorithms::Parameter * par)
5858
{
59-
NumericTable * xTable = const_cast<NumericTable *>(a[0]); /* Input data */
59+
NumericTable * xTable = const_cast<NumericTable *>(a[0]); /* x Input data */
6060
NumericTable * rTable = const_cast<NumericTable *>(r[0]); /* Result */
6161
const NumericTableIface::StorageLayout rLayout = r[0]->getDataLayout();
6262

6363
if (isFull<algorithmFPType, cpu>(rLayout))
6464
{
65-
return corDistanceFull<algorithmFPType, cpu>(xTable, rTable);
65+
if (na == 1)
66+
{
67+
return corDistanceFull<algorithmFPType, cpu>(xTable, rTable);
68+
}
69+
else if (na == 2)
70+
{
71+
NumericTable * yTable = const_cast<NumericTable *>(a[1]); /* y Input data */
72+
return corDistanceFull<algorithmFPType, cpu>(xTable, yTable, rTable);
73+
}
74+
else
75+
{
76+
return services::Status(services::ErrorIncorrectNumberOfInputNumericTables);
77+
}
6678
}
6779
else
6880
{

cpp/daal/src/algorithms/cordistance/cordistance_dense_default_batch_fpt_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ template class BatchContainer<DAAL_FPTYPE, defaultDense, DAAL_CPU>;
3838
}
3939
namespace internal
4040
{
41-
template class DistanceKernel<DAAL_FPTYPE, defaultDense, DAAL_CPU>;
41+
template class DAAL_EXPORT DistanceKernel<DAAL_FPTYPE, defaultDense, DAAL_CPU>;
4242

4343
} // namespace internal
4444

cpp/daal/src/algorithms/cordistance/cordistance_full_impl.i

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
*/
2323
#include "src/services/service_defines.h"
2424
using namespace daal::internal;
25+
using namespace daal::services;
2526

2627
namespace daal
2728
{
@@ -306,6 +307,139 @@ services::Status corDistanceFull(const NumericTable * xTable, NumericTable * rTa
306307
return safeStat.detach();
307308
}
308309

310+
template <typename algorithmFPType, CpuType cpu>
311+
services::Status corDistanceFull(const NumericTable * xTable, const NumericTable * yTable, NumericTable * rTable)
312+
{
313+
size_t p = xTable->getNumberOfColumns(); /* Dimension of input feature vector */
314+
size_t nVectors1 = xTable->getNumberOfRows(); /* Number of input vectors in X */
315+
size_t nVectors2 = yTable->getNumberOfRows(); /* Number of input vectors in Y */
316+
317+
size_t nBlocks1 = nVectors1 / blockSizeDefault;
318+
nBlocks1 += (nBlocks1 * blockSizeDefault != nVectors1);
319+
320+
size_t nBlocks2 = nVectors2 / blockSizeDefault;
321+
nBlocks2 += (nBlocks2 * blockSizeDefault != nVectors2);
322+
323+
SafeStatus safeStat;
324+
325+
/* Allocate yMean for all Y vectors before the loop */
326+
TArray<algorithmFPType, cpu> yMeanArr(nVectors2);
327+
algorithmFPType * yMean = yMeanArr.get();
328+
DAAL_CHECK(yMean, ErrorMemoryAllocationFailed);
329+
330+
/* Compute means for all Y vectors before the loop */
331+
for (size_t k2 = 0; k2 < nBlocks2; k2++)
332+
{
333+
DAAL_INT blockSize2 = blockSizeDefault;
334+
if (k2 == nBlocks2 - 1)
335+
{
336+
blockSize2 = nVectors2 - k2 * blockSizeDefault;
337+
}
338+
339+
size_t shift2 = k2 * blockSizeDefault;
340+
341+
/* read access to blockSize2 rows in input dataset Y */
342+
ReadRows<algorithmFPType, cpu> yBlock(*const_cast<NumericTable *>(yTable), shift2, blockSize2);
343+
DAAL_CHECK_BLOCK_STATUS(yBlock);
344+
const algorithmFPType * y = yBlock.get();
345+
346+
for (size_t j = 0; j < blockSize2; j++)
347+
{
348+
yMean[shift2 + j] = 0.0;
349+
for (size_t k = 0; k < p; k++)
350+
{
351+
yMean[shift2 + j] += y[j * p + k];
352+
}
353+
yMean[shift2 + j] /= p;
354+
}
355+
}
356+
357+
/* compute results for blocks of the distance matrix */
358+
daal::threader_for(nBlocks1, nBlocks1, [=, &safeStat](size_t k1) {
359+
DAAL_INT blockSize1 = blockSizeDefault;
360+
if (k1 == nBlocks1 - 1)
361+
{
362+
blockSize1 = nVectors1 - k1 * blockSizeDefault;
363+
}
364+
365+
/* read access to blockSize1 rows in input dataset X at k1*blockSizeDefault*p row */
366+
ReadRows<algorithmFPType, cpu> xBlock(*const_cast<NumericTable *>(xTable), k1 * blockSizeDefault, blockSize1);
367+
DAAL_CHECK_BLOCK_STATUS_THR(xBlock);
368+
const algorithmFPType * x = xBlock.get();
369+
370+
/* write access to blockSize1 rows in output dataset */
371+
WriteOnlyRows<algorithmFPType, cpu> rBlock(rTable, k1 * blockSizeDefault, blockSize1);
372+
DAAL_CHECK_BLOCK_STATUS_THR(rBlock);
373+
algorithmFPType * r = rBlock.get();
374+
375+
/* Compute means for rows in X block */
376+
TArrayScalable<algorithmFPType, cpu> xMeanArr(blockSize1);
377+
algorithmFPType * xMean = xMeanArr.get();
378+
DAAL_CHECK_THR(xMean, ErrorMemoryAllocationFailed);
379+
380+
for (size_t i = 0; i < blockSize1; i++)
381+
{
382+
xMean[i] = 0.0;
383+
for (size_t j = 0; j < p; j++)
384+
{
385+
xMean[i] += x[i * p + j];
386+
}
387+
xMean[i] /= p;
388+
}
389+
390+
for (size_t k2 = 0; k2 < nBlocks2; k2++)
391+
{
392+
DAAL_INT blockSize2 = blockSizeDefault;
393+
if (k2 == nBlocks2 - 1)
394+
{
395+
blockSize2 = nVectors2 - k2 * blockSizeDefault;
396+
}
397+
398+
size_t shift2 = k2 * blockSizeDefault;
399+
400+
/* read access to blockSize2 rows in input dataset Y */
401+
ReadRows<algorithmFPType, cpu> yBlock(*const_cast<NumericTable *>(yTable), shift2, blockSize2);
402+
DAAL_CHECK_BLOCK_STATUS_THR(yBlock);
403+
const algorithmFPType * y = yBlock.get();
404+
405+
for (size_t i = 0; i < blockSize1; i++)
406+
{
407+
for (size_t j = 0; j < blockSize2; j++)
408+
{
409+
algorithmFPType numerator = 0.0;
410+
algorithmFPType xNorm = 0.0;
411+
algorithmFPType yNorm = 0.0;
412+
413+
for (size_t k = 0; k < p; k++)
414+
{
415+
algorithmFPType x_centered = x[i * p + k] - xMean[i];
416+
algorithmFPType y_centered = y[j * p + k] - yMean[shift2 + j];
417+
418+
numerator += x_centered * y_centered;
419+
xNorm += x_centered * x_centered;
420+
yNorm += y_centered * y_centered;
421+
}
422+
423+
algorithmFPType denominator = xNorm * yNorm;
424+
if (denominator > 0.0)
425+
{
426+
r[i * nVectors2 + shift2 + j] = 1.0
427+
- numerator
428+
/ (daal::internal::MathInst<algorithmFPType, cpu>::sSqrt(xNorm)
429+
* daal::internal::MathInst<algorithmFPType, cpu>::sSqrt(yNorm));
430+
}
431+
else
432+
{
433+
r[i * nVectors2 + shift2 + j] = 1.0; // Maximum distance when no variance
434+
}
435+
}
436+
}
437+
}
438+
});
439+
440+
return safeStat.detach();
441+
}
442+
309443
} // namespace internal
310444

311445
} // namespace correlation_distance

cpp/oneapi/dal/algo/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ALGOS = [
1616
"basic_statistics",
1717
"chebyshev_distance",
1818
"connected_components",
19+
"correlation_distance",
1920
"cosine_distance",
2021
"dbscan",
2122
"decision_tree",
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*******************************************************************************
2+
* Copyright contributors to the oneDAL project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#pragma once
18+
19+
#include "oneapi/dal/algo/correlation_distance/compute.hpp"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package(default_visibility = ["//visibility:public"])
2+
load("@onedal//dev/bazel:dal.bzl",
3+
"dal_module",
4+
"dal_test_suite",
5+
)
6+
7+
dal_module(
8+
name = "correlation_distance",
9+
auto = True,
10+
dal_deps = [
11+
"@onedal//cpp/oneapi/dal:core",
12+
"@onedal//cpp/oneapi/dal/backend/primitives:blas",
13+
"@onedal//cpp/oneapi/dal/backend/primitives:reduction",
14+
"@onedal//cpp/oneapi/dal/backend/primitives:distance",
15+
],
16+
extra_deps = [
17+
"@onedal//cpp/daal/src/algorithms/cordistance:kernel",
18+
"@onedal//cpp/daal:data_management",
19+
]
20+
)
21+
22+
dal_test_suite(
23+
name = "interface_tests",
24+
framework = "catch2",
25+
srcs = glob([
26+
"test/*.cpp",
27+
]),
28+
dal_deps = [
29+
":correlation_distance",
30+
],
31+
)
32+
33+
dal_test_suite(
34+
name = "tests",
35+
tests = [
36+
":interface_tests",
37+
],
38+
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*******************************************************************************
2+
* Copyright contributors to the oneDAL project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#pragma once
18+
19+
#include "oneapi/dal/algo/correlation_distance/compute_types.hpp"
20+
#include "oneapi/dal/backend/dispatcher.hpp"
21+
#include "oneapi/dal/table/homogen.hpp"
22+
23+
namespace oneapi::dal::correlation_distance::backend {
24+
25+
template <typename Float, typename Method, typename Task>
26+
struct compute_kernel_cpu {
27+
compute_result<Task> operator()(const dal::backend::context_cpu& ctx,
28+
const detail::descriptor_base<Task>& params,
29+
const compute_input<Task>& input) const;
30+
31+
#ifdef ONEDAL_DATA_PARALLEL
32+
void operator()(const dal::backend::context_cpu& ctx,
33+
const detail::descriptor_base<Task>& params,
34+
const table& x,
35+
const table& y,
36+
homogen_table& res) const;
37+
#endif
38+
};
39+
40+
} // namespace oneapi::dal::correlation_distance::backend
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*******************************************************************************
2+
* Copyright contributors to the oneDAL project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include <daal/src/algorithms/cordistance/cordistance_kernel.h>
18+
19+
#include "oneapi/dal/algo/correlation_distance/backend/cpu/compute_kernel.hpp"
20+
#include "oneapi/dal/backend/interop/common.hpp"
21+
#include "oneapi/dal/backend/interop/error_converter.hpp"
22+
#include "oneapi/dal/backend/interop/table_conversion.hpp"
23+
#include "oneapi/dal/exceptions.hpp"
24+
25+
#include "oneapi/dal/table/row_accessor.hpp"
26+
27+
namespace oneapi::dal::correlation_distance::backend {
28+
29+
using dal::backend::context_cpu;
30+
using input_t = compute_input<task::compute>;
31+
using result_t = compute_result<task::compute>;
32+
using descriptor_t = detail::descriptor_base<task::compute>;
33+
34+
namespace daal_correlation = daal::algorithms::correlation_distance;
35+
namespace interop = dal::backend::interop;
36+
37+
template <typename Float, daal::CpuType Cpu>
38+
using daal_correlation_t =
39+
daal_correlation::internal::DistanceKernel<Float, daal_correlation::defaultDense, Cpu>;
40+
41+
template <typename Float>
42+
static result_t call_daal_kernel(const context_cpu& ctx,
43+
const descriptor_t& desc,
44+
const table& x,
45+
const table& y) {
46+
const std::int64_t row_count_x = x.get_row_count();
47+
const std::int64_t row_count_y = y.get_row_count();
48+
49+
dal::detail::check_mul_overflow(row_count_x, row_count_y);
50+
auto arr_values = array<Float>::empty(row_count_x * row_count_y);
51+
52+
const auto daal_x = interop::convert_to_daal_table<Float>(x);
53+
const auto daal_y = interop::convert_to_daal_table<Float>(y);
54+
const auto daal_values =
55+
interop::convert_to_daal_homogen_table(arr_values, row_count_x, row_count_y);
56+
57+
daal::algorithms::Parameter param;
58+
const daal::data_management::NumericTable* daal_input_tables[2] = { daal_x.get(),
59+
daal_y.get() };
60+
daal::data_management::NumericTable* daal_result_table[1] = { daal_values.get() };
61+
62+
interop::status_to_exception(
63+
interop::call_daal_kernel<Float, daal_correlation_t>(ctx,
64+
2,
65+
daal_input_tables,
66+
1,
67+
daal_result_table,
68+
&param));
69+
70+
return result_t().set_values(
71+
dal::detail::homogen_table_builder{}.reset(arr_values, row_count_x, row_count_y).build());
72+
}
73+
74+
template <typename Float>
75+
static result_t compute(const context_cpu& ctx, const descriptor_t& desc, const input_t& input) {
76+
return call_daal_kernel<Float>(ctx, desc, input.get_x(), input.get_y());
77+
}
78+
79+
template <typename Float>
80+
struct compute_kernel_cpu<Float, method::dense, task::compute> {
81+
result_t operator()(const context_cpu& ctx,
82+
const descriptor_t& desc,
83+
const input_t& input) const {
84+
return compute<Float>(ctx, desc, input);
85+
}
86+
87+
#ifdef ONEDAL_DATA_PARALLEL
88+
void operator()(const context_cpu& ctx,
89+
const descriptor_t& desc,
90+
const table& x,
91+
const table& y,
92+
homogen_table& res) const {
93+
throw unimplemented(dal::detail::error_messages::method_not_implemented());
94+
}
95+
#endif
96+
};
97+
98+
template struct compute_kernel_cpu<float, method::dense, task::compute>;
99+
template struct compute_kernel_cpu<double, method::dense, task::compute>;
100+
101+
} // namespace oneapi::dal::correlation_distance::backend

0 commit comments

Comments
 (0)