Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/oneapi/dpl/internal/async_impl/async_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,26 @@ __pattern_transform_scan_async(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy
__result, __unary_op, _InitType{}, __binary_op, _Inclusive{});
}

//------------------------------------------------------------------------
// sort
//------------------------------------------------------------------------

// Asynchronous stable sort of [__first, __last) for hetero backends.
//
// The iterator range is wrapped into a read_write SYCL range and handed to
// __par_backend_hetero::__parallel_stable_sort together with the comparator and an identity projection.
// The event and the second element of the backend result are returned in a __future; the second element is
// received as a move-only pointer and converted to the std::shared_ptr required by this function's return type.
//
// Precondition (asserted): the range contains at least two elements.
//
// NOTE(review): the structured binding below requires the backend result to decompose into exactly two
// elements. The radix-sort overload of __parallel_stable_sort visible in this change returns
// __future<sycl::event> only - confirm that this path is never selected here, or handle it separately.
// NOTE(review): it is assumed that __future supports structured bindings - TODO confirm.
template <typename _BackendTag, class _ExecutionPolicy, class _Iterator, class _Compare>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__pattern_stable_sort_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last,
                            _Compare __comp)
{
    assert(__last - __first >= 2);

    auto __keep = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read_write, _Iterator>();
    auto __buf = __keep(__first, __last);

    auto [__e, __p_unique] = __par_backend_hetero::__parallel_stable_sort(
        _BackendTag{}, std::forward<_ExecutionPolicy>(__exec), __buf.all_view(), __comp, oneapi::dpl::identity{});

    // std::unique_ptr is move-only: ownership has to be transferred explicitly into the std::shared_ptr.
    return {std::move(__e), std::shared_ptr<__result_and_scratch_storage_base>{std::move(__p_unique)}};
}

} // namespace __internal
} // namespace dpl
} // namespace oneapi
Expand Down
19 changes: 7 additions & 12 deletions include/oneapi/dpl/internal/async_impl/glue_async_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,12 @@ template <class _ExecutionPolicy, class _Iterator, class _Compare, class... _Eve
auto
sort_async(_ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last, _Compare __comp, _Events&&... __dependencies)
{
wait_for_all(::std::forward<_Events>(__dependencies)...);
assert(__last - __first >= 2);

auto __keep = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read_write, _Iterator>();
auto __buf = __keep(__first, __last);

const auto __dispatch_tag = oneapi::dpl::__internal::__select_backend(__exec, __first);
using __backend_tag = typename decltype(__dispatch_tag)::__backend_tag;

return __par_backend_hetero::__parallel_stable_sort(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec),
__buf.all_view(), __comp, oneapi::dpl::identity{});
wait_for_all(std::forward<_Events>(__dependencies)...);

return oneapi::dpl::__internal::__pattern_stable_sort_async(__dispatch_tag, std::forward<_ExecutionPolicy>(__exec),
__first, __last, __comp);
}

template <class _ExecutionPolicy, class _RandomAccessIterator, class... _Events,
Expand All @@ -112,9 +107,9 @@ auto
sort_async(_ExecutionPolicy&& __exec, _RandomAccessIterator __first, _RandomAccessIterator __last,
_Events&&... __dependencies)
{
using _ValueType = typename ::std::iterator_traits<_RandomAccessIterator>::value_type;
return sort_async(::std::forward<_ExecutionPolicy>(__exec), __first, __last, ::std::less<_ValueType>(),
::std::forward<_Events>(__dependencies)...);
using _ValueType = typename std::iterator_traits<_RandomAccessIterator>::value_type;
return sort_async(std::forward<_ExecutionPolicy>(__exec), __first, __last, std::less<_ValueType>(),
std::forward<_Events>(__dependencies)...);
}

// [async.for_each]
Expand Down
20 changes: 10 additions & 10 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2205,9 +2205,9 @@ struct __is_radix_sort_usable_for_type
{
static constexpr bool value =
#if _ONEDPL_USE_RADIX_SORT
(::std::is_arithmetic_v<_T> || ::std::is_same_v<sycl::half, _T>) &&
(__internal::__is_comp_ascending<::std::decay_t<_Compare>>::value ||
__internal::__is_comp_descending<::std::decay_t<_Compare>>::value);
(std::is_arithmetic_v<_T> || std::is_same_v<sycl::half, _T>) &&
(__internal::__is_comp_ascending<std::decay_t<_Compare>>::value ||
__internal::__is_comp_descending<std::decay_t<_Compare>>::value);
#else
false;
#endif // _ONEDPL_USE_RADIX_SORT
Expand All @@ -2216,26 +2216,26 @@ struct __is_radix_sort_usable_for_type
#if _ONEDPL_USE_RADIX_SORT
template <
typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj,
::std::enable_if_t<
std::enable_if_t<
__is_radix_sort_usable_for_type<oneapi::dpl::__internal::__key_t<_Proj, _Range>, _Compare>::value, int> = 0>
__future<sycl::event>
__parallel_stable_sort(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
_Range&& __rng, _Compare, _Proj __proj)
{
return __parallel_radix_sort<__internal::__is_comp_ascending<::std::decay_t<_Compare>>::value>(
__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), __proj);
return __parallel_radix_sort<__internal::__is_comp_ascending<std::decay_t<_Compare>>::value>(
__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range>(__rng), __proj);
}
#endif // _ONEDPL_USE_RADIX_SORT

template <
typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj,
::std::enable_if_t<
std::enable_if_t<
!__is_radix_sort_usable_for_type<oneapi::dpl::__internal::__key_t<_Proj, _Range>, _Compare>::value, int> = 0>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
__parallel_stable_sort(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
_Range&& __rng, _Compare __comp, _Proj __proj)
{
return __parallel_sort_impl(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng),
return __parallel_sort_impl(__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range>(__rng),
oneapi::dpl::__internal::__compare<_Compare, _Proj>{__comp, __proj});
}

Expand All @@ -2256,7 +2256,7 @@ __parallel_partial_sort(oneapi::dpl::__internal::__device_backend_tag __backend_
auto __keep = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read_write, _Iterator>();
auto __buf = __keep(__first, __last);

return __parallel_partial_sort_impl(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), __buf.all_view(),
return __parallel_partial_sort_impl(__backend_tag, std::forward<_ExecutionPolicy>(__exec), __buf.all_view(),
__partial_merge_kernel<decltype(__mid_idx)>{__mid_idx}, __comp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ template <typename _OutSizeLimit, typename _IdType, typename... _Name>
struct __parallel_merge_submitter<_OutSizeLimit, _IdType, __internal::__optional_kernel_name<_Name...>>
{
template <typename _Range1, typename _Range2, typename _Range3, typename _Compare>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
operator()(sycl::queue& __q, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
{
const _IdType __n1 = __rng1.size();
Expand All @@ -232,7 +232,7 @@ struct __parallel_merge_submitter<_OutSizeLimit, _IdType, __internal::__optional
else
assert(__rng3.size() >= __n1 + __n2);

std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
std::unique_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
static_cast<__result_and_scratch_storage_base*>(__p_res_storage));

auto __event = __q.submit([&__rng1, &__rng2, &__rng3, __p_res_storage, __comp, __chunk, __steps, __n, __n1,
Expand Down Expand Up @@ -261,7 +261,7 @@ struct __parallel_merge_submitter<_OutSizeLimit, _IdType, __internal::__optional
});
});

// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
// Store the raw pointer in a unique_ptr and return it in __future to extend the lifetime of the storage.
// The second parameter of __future must have the same type as in the value returned by
// __parallel_merge_submitter_large::operator(), for compatibility.
return __future{std::move(__event), std::move(__p_result_and_scratch_storage_base)};
Expand Down Expand Up @@ -423,7 +423,7 @@ struct __parallel_merge_submitter_large<_OutSizeLimit, _IdType, _CustomName,

public:
template <typename _Range1, typename _Range2, typename _Range3, typename _Compare>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
operator()(sycl::queue& __q, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
{
const _IdType __n1 = __rng1.size();
Expand All @@ -445,8 +445,8 @@ struct __parallel_merge_submitter_large<_OutSizeLimit, _IdType, _CustomName,
auto __p_base_diagonals_sp_global_storage =
new __result_and_scratch_storage_t(__q, __nd_range_params.base_diag_count + 1);

// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
// Save the raw pointer into a unique_ptr to return it in __future and extend the lifetime of the storage.
std::unique_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
static_cast<__result_and_scratch_storage_base*>(__p_base_diagonals_sp_global_storage));

// Find split-points on the base diagonals
Expand Down Expand Up @@ -486,7 +486,7 @@ __get_starting_size_limit_for_large_submitter<int>()

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare,
typename _OutSizeLimit = std::false_type>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
__parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1,
_Range2&& __rng2, _Range3&& __rng3, _Compare __comp, _OutSizeLimit = {})
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name

public:
template <typename _Range, typename _Compare, typename _TempBuf, typename _LeafSizeT>
std::tuple<sycl::event, bool, std::shared_ptr<__result_and_scratch_storage_base>>
std::tuple<sycl::event, bool, std::unique_ptr<__result_and_scratch_storage_base>>
operator()(sycl::queue& __q, _Range& __rng, _Compare __comp, _LeafSizeT __leaf_size, _TempBuf& __temp_buf,
sycl::event __event_chain) const
{
Expand All @@ -592,8 +592,8 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
// Storage to save split-points on each base diagonal + 1 (for the right base diagonal in the last work-group)
__base_diagonals_sp_storage_t* __p_base_diagonals_sp_global_storage = nullptr;

// shared_ptr instance to return it in __future and extend the lifetime of the storage.
std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base;
// unique_ptr instance to return it in __future and extend the lifetime of the storage.
std::unique_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base;

// Max amount of base diagonals
const std::size_t __max_base_diags_count =
Expand All @@ -616,7 +616,7 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
__p_base_diagonals_sp_global_storage =
new __base_diagonals_sp_storage_t(__q, __max_base_diags_count);

// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
// Save the raw pointer into a unique_ptr to return it in __future and extend the lifetime of the storage.
__p_result_and_scratch_storage_base.reset(
static_cast<__result_and_scratch_storage_base*>(__p_base_diagonals_sp_global_storage));
}
Expand Down Expand Up @@ -693,7 +693,7 @@ template <typename... _Name>
class __sort_copy_back_kernel;

template <typename _CustomName, typename _IndexT, typename _Range, typename _Compare, typename _LeafSorter>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
__merge_sort(sycl::queue& __q, _Range&& __rng, _Compare __comp, _LeafSorter& __leaf_sorter)
{
using _Tp = oneapi::dpl::__internal::__value_t<_Range>;
Expand Down Expand Up @@ -733,7 +733,7 @@ __merge_sort(sycl::queue& __q, _Range&& __rng, _Compare __comp, _LeafSorter& __l
}

template <typename _CustomName, typename _IndexT, typename _Range, typename _Compare>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
__submit_selecting_leaf(sycl::queue& __q, _Range&& __rng, _Compare __comp)
{
using _Leaf = __leaf_sorter<std::decay_t<_Range>, _Compare>;
Expand Down Expand Up @@ -786,7 +786,7 @@ __submit_selecting_leaf(sycl::queue& __q, _Range&& __rng, _Compare __comp)
};

template <typename _ExecutionPolicy, typename _Range, typename _Compare>
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
__parallel_sort_impl(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range&& __rng,
_Compare __comp)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ struct __deferrable_mode
{
};

//A contract for future class: <sycl::event or other event, a value, sycl::buffers..., or __usm_host_or_buffer_storage>
// Contract for the __future class: <sycl::event or another event type, a value, sycl::buffers..., or __result_and_scratch_storage>
// Implementation detail: private inheritance is used instead of aggregation to enable the empty base optimization.
template <typename _Event, typename... _Args>
class __future : private std::tuple<_Args...>
Expand Down Expand Up @@ -770,7 +770,8 @@ class __future : private std::tuple<_Args...>
}

public:
__future(_Event __e, _Args... __args) : std::tuple<_Args...>(__args...), __my_event(__e) {}
// Forwarding constructor: keeps a copy of the event and perfectly forwards every remaining argument
// into the std::tuple<_Args...> base.
template <typename... _Ts>
__future(_Event __e, _Ts&&... __vals)
    : std::tuple<_Args...>(std::forward<_Ts>(__vals)...), __my_event(__e)
{
}
// Constructs the future from an event and a ready-made tuple of values.
// The tuple is taken by value and moved into the base, so tuples that hold move-only elements
// (e.g. std::unique_ptr) can be passed as rvalues; copyable tuples passed as lvalues behave as before.
__future(_Event __e, std::tuple<_Args...> __t) : std::tuple<_Args...>(std::move(__t)), __my_event(__e) {}

auto
Expand Down Expand Up @@ -825,6 +826,9 @@ class __future : private std::tuple<_Args...>
return __future<_Event, _T, _Args...>(__my_event, new_tuple);
}
};
// Deduction guide for __future constructed as __future{event, values...}:
// the first template argument is deduced from the event passed by value, and each remaining argument
// contributes its type with the reference and top-level cv-qualifiers removed.
template <typename _Ev, typename... _Ts>
__future(_Ev, _Ts&&...) -> __future<_Ev, std::remove_cv_t<std::remove_reference_t<_Ts>>...>;

// Invoke a callable and pass a compile-time integer based on a provided run-time integer.
// The compile-time integer that will be provided to the callable is defined as the smallest
Expand Down
Loading