66#include " core/common/safeint.h"
77#include " core/common/status.h"
88#include " core/framework/allocator.h"
9+ #include " core/framework/error_code_helper.h"
910#include " core/mlas/inc/mlas.h"
1011#include " core/framework/utils.h"
1112#include " core/session/ort_apis.h"
@@ -185,22 +186,32 @@ std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return
185186#endif
186187ORT_API_STATUS_IMPL (OrtApis::CreateMemoryInfo, _In_ const char * name1, enum OrtAllocatorType type, int id1,
187188 enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) {
189+ API_IMPL_BEGIN
190+
191+ if (name1 == nullptr ) {
192+ return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " MemoryInfo name cannot be null." );
193+ }
194+
195+ if (out == nullptr ) {
196+ return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " Output memory info cannot be null." );
197+ }
198+
188199 auto device_id = static_cast <OrtDevice::DeviceId>(id1);
189200 if (strcmp (name1, onnxruntime::CPU) == 0 ) {
190201 *out = new OrtMemoryInfo (onnxruntime::CPU, type, OrtDevice (), mem_type1);
191202 } else if (strcmp (name1, onnxruntime::CUDA) == 0 ) {
192203 *out = new OrtMemoryInfo (
193- name1 , type,
204+ onnxruntime::CUDA , type,
194205 OrtDevice (OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device_id),
195206 mem_type1);
196207 } else if (strcmp (name1, onnxruntime::OpenVINO_GPU) == 0 ) {
197208 *out = new OrtMemoryInfo (
198- name1 , type,
209+ onnxruntime::OpenVINO_GPU , type,
199210 OrtDevice (OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id),
200211 mem_type1);
201212 } else if (strcmp (name1, onnxruntime::HIP) == 0 ) {
202213 *out = new OrtMemoryInfo (
203- name1 , type,
214+ onnxruntime::HIP , type,
204215 OrtDevice (OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id),
205216 mem_type1);
206217 } else if (strcmp (name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
@@ -212,45 +223,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
212223
213224 } else if (strcmp (name1, onnxruntime::DML) == 0 ) {
214225 *out = new OrtMemoryInfo (
215- name1 , type,
226+ onnxruntime::DML , type,
216227 OrtDevice (OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device_id),
217228 mem_type1);
218229 } else if (strcmp (name1, onnxruntime::OpenVINO_RT_NPU) == 0 ) {
219230 *out = new OrtMemoryInfo (
220- name1 , type,
231+ onnxruntime::OpenVINO_RT_NPU , type,
221232 OrtDevice (OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id),
222233 mem_type1);
223234 } else if (strcmp (name1, onnxruntime::CUDA_PINNED) == 0 ) {
224235 *out = new OrtMemoryInfo (
225- name1 , type,
236+ onnxruntime::CUDA_PINNED , type,
226237 OrtDevice (OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id),
227238 mem_type1);
228239 } else if (strcmp (name1, onnxruntime::HIP_PINNED) == 0 ) {
229240 *out = new OrtMemoryInfo (
230- name1 , type,
241+ onnxruntime::HIP_PINNED , type,
231242 OrtDevice (OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id),
232243 mem_type1);
233244 } else if (strcmp (name1, onnxruntime::QNN_HTP_SHARED) == 0 ) {
234245 *out = new OrtMemoryInfo (
235- name1 , type,
246+ onnxruntime::QNN_HTP_SHARED , type,
236247 OrtDevice (OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, device_id),
237248 mem_type1);
238249 } else if (strcmp (name1, onnxruntime::CPU_ALIGNED_4K) == 0 ) {
239250 *out = new OrtMemoryInfo (
240- name1 , type,
251+ onnxruntime::CPU_ALIGNED_4K , type,
241252 OrtDevice (OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id,
242253 onnxruntime::kAlloc4KAlignment ),
243254 mem_type1);
244255 } else {
245256 return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " Specified device is not supported. Try CreateMemoryInfo_V2." );
246257 }
258+ API_IMPL_END
247259 return nullptr ;
248260}
249261
250262ORT_API_STATUS_IMPL (OrtApis::CreateMemoryInfo_V2, _In_ const char * name, _In_ enum OrtMemoryInfoDeviceType device_type,
251263 _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
252264 _In_ size_t alignment, enum OrtAllocatorType type,
253265 _Outptr_ OrtMemoryInfo** out) {
266+ API_IMPL_BEGIN
267+
268+ if (name == nullptr ) {
269+ return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " MemoryInfo name cannot be null." );
270+ }
271+
272+ if (out == nullptr ) {
273+ return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " Output memory info cannot be null." );
274+ }
275+
254276 // map the public enum values to internal OrtDevice values
255277 OrtDevice::MemoryType mt = mem_type == OrtDeviceMemoryType_DEFAULT ? OrtDevice::MemType::DEFAULT
256278 : OrtDevice::MemType::HOST_ACCESSIBLE;
@@ -275,6 +297,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en
275297
276298 *out = new OrtMemoryInfo (name, type, OrtDevice{dt, mt, vendor_id, narrow<int16_t >(device_id), alignment},
277299 mem_type == OrtDeviceMemoryType_DEFAULT ? OrtMemTypeDefault : OrtMemTypeCPU);
300+ API_IMPL_END
278301 return nullptr ;
279302}
280303
@@ -283,7 +306,7 @@ ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) { de
283306#pragma warning(pop)
284307#endif
285308ORT_API_STATUS_IMPL (OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char ** out) {
286- *out = ptr->name ;
309+ *out = ptr->name . c_str () ;
287310 return nullptr ;
288311}
289312
0 commit comments