Skip to content

Commit d936366

Browse files
[oneDPL][Tests] Apply std::invoke to comparator and projection function calls (#2509)
1 parent 479c26c commit d936366

File tree

13 files changed

+83
-65
lines changed

13 files changed

+83
-65
lines changed

include/oneapi/dpl/pstl/algorithm_impl.h

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "parallel_backend.h"
3434
#include "parallel_impl.h"
3535
#include "iterator_impl.h"
36-
#include "functional_impl.h" // for oneapi::dpl::identity
36+
#include "functional_impl.h" // for oneapi::dpl::identity, std::invoke
3737

3838
#if _ONEDPL_HETERO_BACKEND
3939
# include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n
@@ -2442,7 +2442,9 @@ __pattern_sort_by_key(_Tag, _ExecutionPolicy&&, _RandomAccessIterator1 __keys_fi
24422442

24432443
auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
24442444
auto __end = __beg + (__keys_last - __keys_first);
2445-
auto __cmp_f = [__comp](const auto& __a, const auto& __b) { return __comp(std::get<0>(__a), std::get<0>(__b)); };
2445+
auto __cmp_f = [__comp](const auto& __a, const auto& __b) {
2446+
return std::invoke(__comp, std::get<0>(__a), std::get<0>(__b));
2447+
};
24462448

24472449
__leaf_sort(__beg, __end, __cmp_f);
24482450
}
@@ -2456,7 +2458,9 @@ __pattern_sort_by_key(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _Ran
24562458
{
24572459
auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
24582460
auto __end = __beg + (__keys_last - __keys_first);
2459-
auto __cmp_f = [__comp](const auto& __a, const auto& __b) { return __comp(std::get<0>(__a), std::get<0>(__b)); };
2461+
auto __cmp_f = [__comp](const auto& __a, const auto& __b) {
2462+
return std::invoke(__comp, std::get<0>(__a), std::get<0>(__b));
2463+
};
24602464

24612465
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
24622466

@@ -2699,8 +2703,9 @@ __pattern_nth_element(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec
26992703
_RandomAccessIterator __x;
27002704
do
27012705
{
2702-
__x = __internal::__pattern_partition(__tag, ::std::forward<_ExecutionPolicy>(__exec), __first + 1, __last,
2703-
[&__comp, __first](const _Tp& __x) { return __comp(__x, *__first); });
2706+
__x = __internal::__pattern_partition(
2707+
__tag, std::forward<_ExecutionPolicy>(__exec), __first + 1, __last,
2708+
[&__comp, __first](const _Tp& __x) { return std::invoke(__comp, __x, *__first); });
27042709
--__x;
27052710
if (__x != __first)
27062711
{
@@ -2715,7 +2720,7 @@ __pattern_nth_element(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec
27152720
else if (__x - __nth < 0)
27162721
{
27172722
// if *x == *nth then we start the new partition at the next index where *x != *nth
2718-
while (!__comp(*__nth, *__x) && !__comp(*__x, *__nth) && __x - __nth < 0)
2723+
while (!std::invoke(__comp, *__nth, *__x) && !std::invoke(__comp, *__x, *__nth) && __x - __nth < 0)
27192724
{
27202725
++__x;
27212726
}
@@ -3220,15 +3225,15 @@ __pattern_includes(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _
32203225
if (__first1 == __last1 || __last2 - __first2 > __last1 - __first1 ||
32213226
// {1}: [**********] or [**********]
32223227
// {2}: [***********] [***********]
3223-
__comp(*__first2, *__first1) || __comp(*(__last1 - 1), *(__last2 - 1)))
3228+
std::invoke(__comp, *__first2, *__first1) || std::invoke(__comp, *(__last1 - 1), *(__last2 - 1)))
32243229
return false;
32253230

32263231
__first1 = ::std::lower_bound(__first1, __last1, *__first2, __comp);
32273232
if (__first1 == __last1)
32283233
return false;
32293234

32303235
if (__last2 - __first2 == 1)
3231-
return !__comp(*__first1, *__first2) && !__comp(*__first2, *__first1);
3236+
return !std::invoke(__comp, *__first1, *__first2) && !std::invoke(__comp, *__first2, *__first1);
32323237

32333238
return __internal::__except_handler([&]() {
32343239
return !__internal::__parallel_or(
@@ -3240,7 +3245,7 @@ __pattern_includes(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _
32403245
//1. moving boundaries to "consume" subsequence of equal elements
32413246
auto __is_equal_sorted = [&__comp](_RandomAccessIterator2 __a, _RandomAccessIterator2 __b) -> bool {
32423247
//enough one call of __comp due to compared couple belongs to one sorted sequence
3243-
return !__comp(*__a, *__b);
3248+
return !std::invoke(__comp, *__a, *__b);
32443249
};
32453250

32463251
//1.1 left bound, case "aaa[aaaxyz...]" - searching "x"
@@ -3260,8 +3265,8 @@ __pattern_includes(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _
32603265
//2. testing is __a subsequence of the second range included into the first range
32613266
auto __b = ::std::lower_bound(__first1, __last1, *__i, __comp);
32623267

3263-
assert(!__comp(*(__last1 - 1), *__b));
3264-
assert(!__comp(*(__j - 1), *__i));
3268+
assert(!std::invoke(__comp, *(__last1 - 1), *__b));
3269+
assert(!std::invoke(__comp, *(__j - 1), *__i));
32653270
return !::std::includes(__b, __last1, __i, __j, __comp);
32663271
});
32673272
});
@@ -3846,9 +3851,10 @@ __brick_is_heap_until(_RandomAccessIterator __first, _RandomAccessIterator __las
38463851
/* __is_vector = */ ::std::true_type) noexcept
38473852
{
38483853
using _SizeType = typename std::iterator_traits<_RandomAccessIterator>::difference_type;
3849-
return __unseq_backend::__simd_first(
3850-
__first, _SizeType(0), __last - __first,
3851-
[&__comp](_RandomAccessIterator __it, _SizeType __i) { return __comp(__it[(__i - 1) / 2], __it[__i]); });
3854+
return __unseq_backend::__simd_first(__first, _SizeType(0), __last - __first,
3855+
[&__comp](_RandomAccessIterator __it, _SizeType __i) {
3856+
return std::invoke(__comp, __it[(__i - 1) / 2], __it[__i]);
3857+
});
38523858
}
38533859

38543860
template <class _Tag, class _ExecutionPolicy, class _RandomAccessIterator, class _Compare>
@@ -3868,7 +3874,7 @@ __is_heap_until_local(_RandomAccessIterator __first, _DifferenceType __begin, _D
38683874
{
38693875
_DifferenceType __i = __begin;
38703876
for (; __i < __end; ++__i)
3871-
if (__comp(__first[(__i - 1) / 2], __first[__i]))
3877+
if (std::invoke(__comp, __first[(__i - 1) / 2], __first[__i]))
38723878
break;
38733879
return __first + __i;
38743880
}
@@ -3878,9 +3884,10 @@ _RandomAccessIterator
38783884
__is_heap_until_local(_RandomAccessIterator __first, _DifferenceType __begin, _DifferenceType __end, _Compare __comp,
38793885
/* __is_vector = */ ::std::true_type) noexcept
38803886
{
3881-
return __unseq_backend::__simd_first(
3882-
__first, __begin, __end,
3883-
[&__comp](_RandomAccessIterator __it, _DifferenceType __i) { return __comp(__it[(__i - 1) / 2], __it[__i]); });
3887+
return __unseq_backend::__simd_first(__first, __begin, __end,
3888+
[&__comp](_RandomAccessIterator __it, _DifferenceType __i) {
3889+
return std::invoke(__comp, __it[(__i - 1) / 2], __it[__i]);
3890+
});
38843891
}
38853892

38863893
template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator, class _Compare>
@@ -3916,7 +3923,7 @@ __brick_is_heap(_RandomAccessIterator __first, _RandomAccessIterator __last, _Co
39163923
/* __is_vector = */ ::std::true_type) noexcept
39173924
{
39183925
return !__unseq_backend::__simd_or_iter(__first, __last - __first, [__first, &__comp](_RandomAccessIterator __it) {
3919-
return __comp(*(__first + (__it - __first - 1) / 2), *__it);
3926+
return std::invoke(__comp, *(__first + (__it - __first - 1) / 2), *__it);
39203927
});
39213928
}
39223929

@@ -3933,10 +3940,10 @@ bool
39333940
__is_heap_local(_RandomAccessIterator __first, _DifferenceType __begin, _DifferenceType __end, _Compare __comp,
39343941
/* __is_vector = */ ::std::true_type) noexcept
39353942
{
3936-
return !__unseq_backend::__simd_or_iter(__first + __begin, __end - __begin,
3937-
[__first, &__comp](_RandomAccessIterator __it) {
3938-
return __comp(*(__first + (__it - __first - 1) / 2), *__it);
3939-
});
3943+
return !__unseq_backend::__simd_or_iter(
3944+
__first + __begin, __end - __begin, [__first, &__comp](_RandomAccessIterator __it) {
3945+
return std::invoke(__comp, *(__first + (__it - __first - 1) / 2), *__it);
3946+
});
39403947
}
39413948

39423949
template <class _Tag, class _ExecutionPolicy, class _RandomAccessIterator, class _Compare>
@@ -4205,16 +4212,16 @@ __brick_lexicographical_compare(_RandomAccessIterator1 __first1, _RandomAccessIt
42054212
auto __n = ::std::min(__last1 - __first1, __last2 - __first2);
42064213
::std::pair<_RandomAccessIterator1, _RandomAccessIterator2> __result = __unseq_backend::__simd_first(
42074214
__first1, __n, __first2, [__comp](const ref_type1 __x, const ref_type2 __y) mutable {
4208-
return __comp(__x, __y) || __comp(__y, __x);
4215+
return std::invoke(__comp, __x, __y) || std::invoke(__comp, __y, __x);
42094216
});
42104217

42114218
if (__result.first == __last1 && __result.second != __last2)
42124219
{ // if first sequence shorter than second
4213-
return !__comp(*__result.second, *__result.first);
4220+
return !std::invoke(__comp, *__result.second, *__result.first);
42144221
}
42154222
else
42164223
{ // if second sequence shorter than first or both have the same number of elements
4217-
return __comp(*__result.first, *__result.second);
4224+
return std::invoke(__comp, *__result.first, *__result.second);
42184225
}
42194226
}
42204227
}
@@ -4261,7 +4268,7 @@ __pattern_lexicographical_compare(__parallel_tag<_IsVector> __tag, _ExecutionPol
42614268
return __internal::__brick_mismatch(
42624269
__i, __j, __first2 + (__i - __first1), __first2 + (__j - __first1),
42634270
[&__comp](const _RefType1 __x, const _RefType2 __y) {
4264-
return !__comp(__x, __y) && !__comp(__y, __x);
4271+
return !std::invoke(__comp, __x, __y) && !std::invoke(__comp, __y, __x);
42654272
},
42664273
_IsVector{})
42674274
.first;
@@ -4270,11 +4277,11 @@ __pattern_lexicographical_compare(__parallel_tag<_IsVector> __tag, _ExecutionPol
42704277

42714278
if (__result == __last1 && __first2 + (__result - __first1) != __last2)
42724279
{ // if first sequence shorter than second
4273-
return !__comp(*(__first2 + (__result - __first1)), *__result);
4280+
return !std::invoke(__comp, *(__first2 + (__result - __first1)), *__result);
42744281
}
42754282
else
42764283
{ // if second sequence shorter than first or both have the same number of elements
4277-
return __comp(*__result, *(__first2 + (__result - __first1)));
4284+
return std::invoke(__comp, *__result, *(__first2 + (__result - __first1)));
42784285
}
42794286
});
42804287
}

include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,8 +1415,8 @@ struct __pattern_lexicographical_compare_transform_fn
14151415
auto const& __s1_val = __acc1[__gidx];
14161416
auto const& __s2_val = __acc2[__gidx];
14171417

1418-
::std::int32_t __is_s1_val_less = __comp(__s1_val, __s2_val);
1419-
::std::int32_t __is_s1_val_greater = __comp(__s2_val, __s1_val);
1418+
std::int32_t __is_s1_val_less = std::invoke(__comp, __s1_val, __s2_val);
1419+
std::int32_t __is_s1_val_greater = std::invoke(__comp, __s2_val, __s1_val);
14201420

14211421
// 1 if __s1_val < __s2_val, -1 if __s1_val < __s2_val, 0 if __s1_val == __s2_val
14221422
return _ReduceValueType{1 * __is_s1_val_less - 1 * __is_s1_val_greater};

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct __subgroup_bubble_sorter
4848
{
4949
auto& __first_item = __storage_acc[j - 1];
5050
auto& __second_item = __storage_acc[j];
51-
if (__comp(__second_item, __first_item))
51+
if (std::invoke(__comp, __second_item, __first_item))
5252
{
5353
using std::swap;
5454
swap(__first_item, __second_item);

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <utility>
2222
#include <cstdint>
2323
#include <algorithm>
24+
#include <functional> // for std::invoke
2425

2526
#include "sycl_defs.h"
2627
#include "parallel_backend_sycl_utils.h"
@@ -219,7 +220,7 @@ __radix_sort_count_submit(sycl::queue& __q, std::size_t __segments, std::size_t
219220
__val_idx += __wg_size)
220221
{
221222
// get the bucket for the bit-ordered input value, applying the offset and mask for radix bits
222-
auto __val = __order_preserving_cast<__is_ascending>(__proj(__val_rng[__val_idx]));
223+
auto __val = __order_preserving_cast<__is_ascending>(std::invoke(__proj, __val_rng[__val_idx]));
223224
::std::uint32_t __bucket = __get_bucket<(1 << __radix_bits) - 1>(__val, __radix_offset);
224225
// increment counter for this bit bucket
225226
++__count_arr[__bucket];
@@ -597,7 +598,7 @@ __radix_sort_reorder_submit(sycl::queue& __q, std::size_t __segments, std::size_
597598
_ValueT __in_val = std::move(__input_rng[__val_idx]);
598599
// get the bucket for the bit-ordered input value, applying the offset and mask for radix bits
599600
::std::uint32_t __bucket = __get_bucket<(1 << __radix_bits) - 1>(
600-
__order_preserving_cast<__is_ascending>(__proj(__in_val)), __radix_offset);
601+
__order_preserving_cast<__is_ascending>(std::invoke(__proj, __in_val)), __radix_offset);
601602

602603
const auto __new_offset_idx = __peer_prefix_hlp.__peer_contribution(__bucket, __offset_arr);
603604
__output_rng[__new_offset_idx] = std::move(__in_val);
@@ -614,7 +615,7 @@ __radix_sort_reorder_submit(sycl::queue& __q, std::size_t __segments, std::size_
614615
new (&__in_val.__v) _ValueT(std::move(__input_rng[__seg_end + __self_lidx]));
615616

616617
__bucket = __get_bucket<(1 << __radix_bits) - 1>(
617-
__order_preserving_cast<__is_ascending>(__proj(__in_val.__v)), __radix_offset);
618+
__order_preserving_cast<__is_ascending>(std::invoke(__proj, __in_val.__v)), __radix_offset);
618619
}
619620
const auto __new_offset_idx = __peer_prefix_hlp.__peer_contribution(__bucket, __offset_arr);
620621
if (__self_lidx < __residual)

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include "sycl_traits.h" //SYCL traits specialization for some oneDPL types.
2020

21+
#include <functional> // for std::invoke
22+
2123
//The file is an internal file and the code of that file is included by a major file into the following namespaces:
2224
//namespace oneapi
2325
//{
@@ -214,9 +216,13 @@ struct __subgroup_radix_sort
214216
for (uint16_t __i = 0; __i < __block_size; ++__i)
215217
{
216218
const uint16_t __idx = __wi * __block_size + __i;
217-
const uint16_t __bin = __idx < __n ? __get_bucket</*mask*/ __bin_count - 1>(
218-
__order_preserving_cast<__is_asc>(__proj(__values.__v[__i])), __begin_bit)
219-
: __bin_count - 1/*default bin for out of range elements (when idx >= n)*/;
219+
const uint16_t __bin =
220+
__idx < __n
221+
? __get_bucket</*mask*/ __bin_count - 1>(
222+
__order_preserving_cast<__is_asc>(
223+
std::invoke(__proj, __values.__v[__i])),
224+
__begin_bit)
225+
: __bin_count - 1 /*default bin for out of range elements (when idx >= n)*/;
220226

221227
//"counting" and local offset calculation
222228
__counters[__i] = &__pcounter[__bin * __wg_size];

include/oneapi/dpl/pstl/hetero/utils_hetero.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ struct __pattern_minmax_element_reduce_fn
6363
auto __chosen_for_min = __a;
6464
auto __chosen_for_max = __b;
6565

66-
if (__comp(get<2>(__b), get<2>(__a)))
66+
if (std::invoke(__comp, get<2>(__b), get<2>(__a)))
6767
__chosen_for_min = std::move(__b);
68-
if (__comp(get<3>(__b), get<3>(__a)))
68+
if (std::invoke(__comp, get<3>(__b), get<3>(__a)))
6969
__chosen_for_max = std::move(__a);
7070
return _ReduceValueType{get<0>(__chosen_for_min), get<1>(__chosen_for_max), get<2>(__chosen_for_min),
7171
get<3>(__chosen_for_max)};
@@ -87,7 +87,7 @@ struct __pattern_min_element_reduce_fn
8787
{
8888
// This operator doesn't track the lowest found index in case of equal min. or max. values. Thus, this operator is
8989
// not commutative.
90-
if (__comp(get<1>(__b), get<1>(__a)))
90+
if (std::invoke(__comp, get<1>(__b), get<1>(__a)))
9191
{
9292
return __b;
9393
}
@@ -97,8 +97,8 @@ struct __pattern_min_element_reduce_fn
9797
{
9898
// This operator keeps track of the lowest found index in case of equal min. or max. values. Thus, this operator is
9999
// commutative.
100-
bool _is_a_lt_b = __comp(get<1>(__a), get<1>(__b));
101-
bool _is_b_lt_a = __comp(get<1>(__b), get<1>(__a));
100+
bool _is_a_lt_b = std::invoke(__comp, get<1>(__a), get<1>(__b));
101+
bool _is_b_lt_a = std::invoke(__comp, get<1>(__b), get<1>(__a));
102102

103103
if (_is_b_lt_a || (!_is_a_lt_b && get<0>(__b) < get<0>(__a)))
104104
{

include/oneapi/dpl/pstl/parallel_backend_tbb.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <cassert>
2222
#include <algorithm>
2323
#include <type_traits>
24+
#include <functional> // for std::invoke
2425

2526
#include "parallel_backend_utils.h"
2627
#include "execution_impl.h"
@@ -869,12 +870,12 @@ class __merge_func
869870
{
870871
assert(::std::is_sorted(_M_x_beg + _M_xs, _M_x_beg + _M_xs + __kx, _M_comp));
871872
assert(::std::is_sorted(_M_x_beg + _M_ys, _M_x_beg + _M_ys + __ky, _M_comp));
872-
return !_M_comp(*(_M_x_beg + _M_ys), *(_M_x_beg + _M_xs + __kx - 1));
873+
return !std::invoke(_M_comp, *(_M_x_beg + _M_ys), *(_M_x_beg + _M_xs + __kx - 1));
873874
}
874875

875876
assert(::std::is_sorted(_M_z_beg + _M_xs, _M_z_beg + _M_xs + __kx, _M_comp));
876877
assert(::std::is_sorted(_M_z_beg + _M_ys, _M_z_beg + _M_ys + __ky, _M_comp));
877-
return !_M_comp(*(_M_z_beg + _M_zs + __nx), *(_M_z_beg + _M_zs + __kx - 1));
878+
return !std::invoke(_M_comp, *(_M_z_beg + _M_zs + __nx), *(_M_z_beg + _M_zs + __kx - 1));
878879
}
879880
void
880881
move_x_range()

include/oneapi/dpl/pstl/parallel_backend_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include <cassert>
2727
#include "utils.h"
2828
#include "memory_fwd.h"
29-
#include "functional_impl.h" // for oneapi::dpl::identity
29+
#include "functional_impl.h" // for oneapi::dpl::identity, std::invoke
3030

3131
namespace oneapi
3232
{
@@ -110,7 +110,7 @@ struct __serial_move_merge
110110
{
111111
for (;;)
112112
{
113-
if (__comp(*__ys, *__xs))
113+
if (std::invoke(__comp, *__ys, *__xs))
114114
{
115115
const auto __i = __zs - __zs_beg;
116116
if (__i < __nx)

0 commit comments

Comments
 (0)