Skip to content

Commit a9eab90

Browse files
committed
Use unique_ptr instead of shared_ptr
1 parent a6461f7 commit a9eab90

File tree

6 files changed

+43
-39
lines changed

6 files changed

+43
-39
lines changed

include/oneapi/dpl/internal/async_impl/async_impl_hetero.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,26 @@ __pattern_transform_scan_async(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy
240240
__result, __unary_op, _InitType{}, __binary_op, _Inclusive{});
241241
}
242242

243+
//------------------------------------------------------------------------
244+
// sort
245+
//------------------------------------------------------------------------
246+
247+
template <class _ExecutionPolicy, class _Iterator, class _Compare>
248+
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
249+
__pattern_stable_sort_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last,
250+
_Compare __comp)
251+
{
252+
assert(__last - __first >= 2);
253+
254+
auto __keep = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read_write, _Iterator>();
255+
auto __buf = __keep(__first, __last);
256+
257+
auto [__e, __p_unique] = __par_backend_hetero::__parallel_stable_sort(
258+
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), __buf.all_view(), __comp, oneapi::dpl::identity{});
259+
260+
return {std::move(__e), std::shared_ptr{__p_unique}};
261+
}
262+
243263
} // namespace __internal
244264
} // namespace dpl
245265
} // namespace oneapi

include/oneapi/dpl/internal/async_impl/glue_async_impl.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,12 @@ template <class _ExecutionPolicy, class _Iterator, class _Compare, class... _Eve
9393
auto
9494
sort_async(_ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last, _Compare __comp, _Events&&... __dependencies)
9595
{
96-
wait_for_all(::std::forward<_Events>(__dependencies)...);
97-
assert(__last - __first >= 2);
98-
99-
auto __keep = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read_write, _Iterator>();
100-
auto __buf = __keep(__first, __last);
101-
10296
const auto __dispatch_tag = oneapi::dpl::__internal::__select_backend(__exec, __first);
103-
using __backend_tag = typename decltype(__dispatch_tag)::__backend_tag;
10497

105-
return __par_backend_hetero::__parallel_stable_sort(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec),
106-
__buf.all_view(), __comp, oneapi::dpl::identity{});
98+
wait_for_all(std::forward<_Events>(__dependencies)...);
99+
100+
return oneapi::dpl::__internal::__pattern_stable_sort_async(
101+
__dispatch_tag, std::forward<_ExecutionPolicy>(__exec), __first, __last, __comp);
107102
}
108103

109104
template <class _ExecutionPolicy, class _RandomAccessIterator, class... _Events,
@@ -112,9 +107,9 @@ auto
112107
sort_async(_ExecutionPolicy&& __exec, _RandomAccessIterator __first, _RandomAccessIterator __last,
113108
_Events&&... __dependencies)
114109
{
115-
using _ValueType = typename ::std::iterator_traits<_RandomAccessIterator>::value_type;
116-
return sort_async(::std::forward<_ExecutionPolicy>(__exec), __first, __last, ::std::less<_ValueType>(),
117-
::std::forward<_Events>(__dependencies)...);
110+
using _ValueType = typename std::iterator_traits<_RandomAccessIterator>::value_type;
111+
return sort_async(std::forward<_ExecutionPolicy>(__exec), __first, __last, std::less<_ValueType>(),
112+
std::forward<_Events>(__dependencies)...);
118113
}
119114

120115
// [async.for_each]

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2231,7 +2231,7 @@ template <
22312231
typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj,
22322232
::std::enable_if_t<
22332233
!__is_radix_sort_usable_for_type<oneapi::dpl::__internal::__key_t<_Proj, _Range>, _Compare>::value, int> = 0>
2234-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
2234+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
22352235
__parallel_stable_sort(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
22362236
_Range&& __rng, _Compare __comp, _Proj __proj)
22372237
{

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ template <typename _OutSizeLimit, typename _IdType, typename... _Name>
205205
struct __parallel_merge_submitter<_OutSizeLimit, _IdType, __internal::__optional_kernel_name<_Name...>>
206206
{
207207
template <typename _Range1, typename _Range2, typename _Range3, typename _Compare>
208-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
208+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
209209
operator()(sycl::queue& __q, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
210210
{
211211
const _IdType __n1 = __rng1.size();
@@ -232,9 +232,6 @@ struct __parallel_merge_submitter<_OutSizeLimit, _IdType, __internal::__optional
232232
else
233233
assert(__rng3.size() >= __n1 + __n2);
234234

235-
std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
236-
static_cast<__result_and_scratch_storage_base*>(__p_res_storage));
237-
238235
auto __event = __q.submit([&__rng1, &__rng2, &__rng3, __p_res_storage, __comp, __chunk, __steps, __n, __n1,
239236
__n2](sycl::handler& __cgh) {
240237
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2, __rng3);
@@ -261,10 +258,10 @@ struct __parallel_merge_submitter<_OutSizeLimit, _IdType, __internal::__optional
261258
});
262259
});
263260

264-
// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
261+
// Save the raw pointer into a unique_ptr to return it in __future and extend the lifetime of the storage.
265262
// We should return the same thing in the second param of __future for compatibility
266263
// with the returning value in __parallel_merge_submitter_large::operator()
267-
return __future{std::move(__event), std::move(__p_result_and_scratch_storage_base)};
264+
return __future{std::move(__event), std::unique_ptr<__result_and_scratch_storage_base>{__p_res_storage}};
268265
}
269266

270267
private:
@@ -423,7 +420,7 @@ struct __parallel_merge_submitter_large<_OutSizeLimit, _IdType, _CustomName,
423420

424421
public:
425422
template <typename _Range1, typename _Range2, typename _Range3, typename _Compare>
426-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
423+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
427424
operator()(sycl::queue& __q, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
428425
{
429426
const _IdType __n1 = __rng1.size();
@@ -445,10 +442,6 @@ struct __parallel_merge_submitter_large<_OutSizeLimit, _IdType, _CustomName,
445442
auto __p_base_diagonals_sp_global_storage =
446443
new __result_and_scratch_storage_t(__q, __nd_range_params.base_diag_count + 1);
447444

448-
// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
449-
std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
450-
static_cast<__result_and_scratch_storage_base*>(__p_base_diagonals_sp_global_storage));
451-
452445
// Find split-points on the base diagonals
453446
sycl::event __event = eval_split_points_for_groups(__q, __rng1, __rng2, __n, __comp, __nd_range_params,
454447
*__p_base_diagonals_sp_global_storage);
@@ -457,7 +450,8 @@ struct __parallel_merge_submitter_large<_OutSizeLimit, _IdType, _CustomName,
457450
__event = run_parallel_merge(__event, __q, __rng1, __rng2, __rng3, __comp, __nd_range_params,
458451
*__p_base_diagonals_sp_global_storage);
459452

460-
return __future{std::move(__event), std::move(__p_result_and_scratch_storage_base)};
453+
return __future{std::move(__event), std::unique_ptr<__result_and_scratch_storage_base>{
454+
__p_base_diagonals_sp_global_storage}};
461455
}
462456
};
463457

@@ -486,7 +480,7 @@ __get_starting_size_limit_for_large_submitter<int>()
486480

487481
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare,
488482
typename _OutSizeLimit = std::false_type>
489-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
483+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
490484
__parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1,
491485
_Range2&& __rng2, _Range3&& __rng3, _Compare __comp, _OutSizeLimit = {})
492486
{

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
565565

566566
public:
567567
template <typename _Range, typename _Compare, typename _TempBuf, typename _LeafSizeT>
568-
std::tuple<sycl::event, bool, std::shared_ptr<__result_and_scratch_storage_base>>
568+
std::tuple<sycl::event, bool, std::unique_ptr<__result_and_scratch_storage_base>>
569569
operator()(sycl::queue& __q, _Range& __rng, _Compare __comp, _LeafSizeT __leaf_size, _TempBuf& __temp_buf,
570570
sycl::event __event_chain) const
571571
{
@@ -592,9 +592,6 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
592592
// Storage to save split-points on each base diagonal + 1 (for the right base diagonal in the last work-group)
593593
__base_diagonals_sp_storage_t* __p_base_diagonals_sp_global_storage = nullptr;
594594

595-
// shared_ptr instance to return it in __future and extend the lifetime of the storage.
596-
std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base;
597-
598595
// Max amount of base diagonals
599596
const std::size_t __max_base_diags_count =
600597
get_max_base_diags_count(__q, __nd_range_params.chunk, __n) + __1_final_base_diag;
@@ -615,10 +612,6 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
615612
// Create storage to save split-points on each base diagonal + 1 (for the right base diagonal in the last work-group)
616613
__p_base_diagonals_sp_global_storage =
617614
new __base_diagonals_sp_storage_t(__q, __max_base_diags_count);
618-
619-
// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
620-
__p_result_and_scratch_storage_base.reset(
621-
static_cast<__result_and_scratch_storage_base*>(__p_base_diagonals_sp_global_storage));
622615
}
623616

624617
nd_range_params __nd_range_params_this =
@@ -649,7 +642,9 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
649642
__data_in_temp = !__data_in_temp;
650643
}
651644

652-
return {std::move(__event_chain), __data_in_temp, std::move(__p_result_and_scratch_storage_base)};
645+
// Save the raw pointer into a unique_ptr to return it in __future and extend the lifetime of the storage.
646+
return {std::move(__event_chain), __data_in_temp,
647+
std::unique_ptr<__result_and_scratch_storage_base>{__p_base_diagonals_sp_global_storage}};
653648
}
654649
};
655650

@@ -693,7 +688,7 @@ template <typename... _Name>
693688
class __sort_copy_back_kernel;
694689

695690
template <typename _CustomName, typename _IndexT, typename _Range, typename _Compare, typename _LeafSorter>
696-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
691+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
697692
__merge_sort(sycl::queue& __q, _Range&& __rng, _Compare __comp, _LeafSorter& __leaf_sorter)
698693
{
699694
using _Tp = oneapi::dpl::__internal::__value_t<_Range>;
@@ -733,7 +728,7 @@ __merge_sort(sycl::queue& __q, _Range&& __rng, _Compare __comp, _LeafSorter& __l
733728
}
734729

735730
template <typename _CustomName, typename _IndexT, typename _Range, typename _Compare>
736-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
731+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
737732
__submit_selecting_leaf(sycl::queue& __q, _Range&& __rng, _Compare __comp)
738733
{
739734
using _Leaf = __leaf_sorter<std::decay_t<_Range>, _Compare>;
@@ -786,7 +781,7 @@ __submit_selecting_leaf(sycl::queue& __q, _Range&& __rng, _Compare __comp)
786781
};
787782

788783
template <typename _ExecutionPolicy, typename _Range, typename _Compare>
789-
__future<sycl::event, std::shared_ptr<__result_and_scratch_storage_base>>
784+
__future<sycl::event, std::unique_ptr<__result_and_scratch_storage_base>>
790785
__parallel_sort_impl(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range&& __rng,
791786
_Compare __comp)
792787
{

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ struct __deferrable_mode
729729
{
730730
};
731731

732-
//A contract for future class: <sycl::event or other event, a value, sycl::buffers..., or __usm_host_or_buffer_storage>
732+
//A contract for future class: <sycl::event or other event, a value, sycl::buffers..., or __result_and_scratch_storage>
733733
//Impl details: inheritance (private) instead of aggregation for enabling the empty base optimization.
734734
template <typename _Event, typename... _Args>
735735
class __future : private std::tuple<_Args...>

0 commit comments

Comments
 (0)