ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5//
6// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8// all the resources follow RAII and do not leak memory.
9//
10// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12// until you assign an instance that actually holds an underlying object.
13//
14// For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15// Some objects have explicit 'Clone' methods for this purpose.
16//
17// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18// by value or by reference. ConstXXXX types are restricted to const only interfaces.
19//
20// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21//
// The lifetime of the corresponding owning object must exceed the lifetimes of the ConstXXXX/UnownedXXXX types. These types exist so you do not
// have to fall back to the C types and C API, with their usual pitfalls. In general, do not use the C API from your C++ code.
24
25#pragma once
#include "onnxruntime_c_api.h"
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include <type_traits>

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
41
45namespace Ort {
46
51struct Exception : std::exception {
52 Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
53
54 OrtErrorCode GetOrtErrorCode() const { return code_; }
55 const char* what() const noexcept override { return message_.c_str(); }
56
57 private:
58 std::string message_;
59 OrtErrorCode code_;
60};
61
#ifdef ORT_NO_EXCEPTIONS
// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
#ifndef ORT_CXX_API_THROW
// Exceptions disabled: print the message to std::cerr, then terminate the process with std::abort
// (fully qualified, from <cstdlib>, instead of relying on an unqualified, transitively declared abort()).
#define ORT_CXX_API_THROW(string, code)       \
  do {                                        \
    std::cerr << Ort::Exception(string, code) \
                     .what()                  \
              << std::endl;                   \
    std::abort();                             \
  } while (false)
#endif
#else
// Exceptions enabled: throw an Ort::Exception carrying the message and error code.
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
78
79// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
80// it's in a template so that we can define a global variable in a header and make
81// it transparent to the users of the API.
82template <typename T>
83struct Global {
84 static const OrtApi* api_;
85};
86
87// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
88template <typename T>
89#ifdef ORT_API_MANUAL_INIT
90const OrtApi* Global<T>::api_{};
91inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
92
93// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
94// required by C++ APIs.
95//
96// Example mycustomop.cc:
97//
98// #define ORT_API_MANUAL_INIT
99// #include <onnxruntime_cxx_api.h>
100// #undef ORT_API_MANUAL_INIT
101//
102// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
103// Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
104// // ...
105// }
106//
107inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
108#else
109#if defined(_MSC_VER) && !defined(__clang__)
110#pragma warning(push)
111// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
// Please define ORT_API_MANUAL_INIT if it concerns you.
113#pragma warning(disable : 26426)
114#endif
116#if defined(_MSC_VER) && !defined(__clang__)
117#pragma warning(pop)
118#endif
119#endif
120
122inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
123
/// Returns the ONNX Runtime version as a string. NOTE(review): definition not visible here — confirm format.
std::string GetVersionString();

/// Returns a string with build information. NOTE(review): exact content not visible here — confirm against the definition.
std::string GetBuildInfoString();

/// Returns the names of the available execution providers.
/// NOTE(review): whether "available" means compiled-in or usable at runtime is not visible here — confirm.
std::vector<std::string> GetAvailableProviders();
143
/// \brief 16-bit storage type for half-precision float tensor elements.
///
/// Holds only the raw bit pattern in `value`; no conversion to/from float and no arithmetic is provided here.
/// Equality compares the raw bits. NOTE(review): presumably IEEE 754 binary16 bits — confirm.
struct Float16_t {
  uint16_t value;
  constexpr Float16_t() noexcept : value(0) {}
  // Not explicit: a raw uint16_t bit pattern converts implicitly to Float16_t.
  constexpr Float16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }
};

// The wrapper must be layout-compatible with the raw 16-bit storage it stands in for.
static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
193
203 uint16_t value;
204 constexpr BFloat16_t() noexcept : value(0) {}
205 constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
206 constexpr operator uint16_t() const noexcept { return value; }
207 constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
208 constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
209};
210
211static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
212
213namespace detail {
214// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
215// This can't be done in the C API since C doesn't have function overloading.
216#define ORT_DEFINE_RELEASE(NAME) \
217 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
218
219ORT_DEFINE_RELEASE(Allocator);
220ORT_DEFINE_RELEASE(MemoryInfo);
221ORT_DEFINE_RELEASE(CustomOpDomain);
222ORT_DEFINE_RELEASE(ThreadingOptions);
223ORT_DEFINE_RELEASE(Env);
224ORT_DEFINE_RELEASE(RunOptions);
225ORT_DEFINE_RELEASE(Session);
226ORT_DEFINE_RELEASE(SessionOptions);
227ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
228ORT_DEFINE_RELEASE(SequenceTypeInfo);
229ORT_DEFINE_RELEASE(MapTypeInfo);
230ORT_DEFINE_RELEASE(TypeInfo);
231ORT_DEFINE_RELEASE(Value);
232ORT_DEFINE_RELEASE(ModelMetadata);
233ORT_DEFINE_RELEASE(IoBinding);
234ORT_DEFINE_RELEASE(ArenaCfg);
235ORT_DEFINE_RELEASE(Status);
236ORT_DEFINE_RELEASE(OpAttr);
237ORT_DEFINE_RELEASE(Op);
238ORT_DEFINE_RELEASE(KernelInfo);
239
240#undef ORT_DEFINE_RELEASE
241
/// Tag type: Base<Unowned<T>> selects the non-owning specialization of Base for handle type T.
/// `Type` names the wrapped T (which may be const-qualified, e.g. Unowned<const T>).
template <typename T>
struct Unowned {
  using Type = T;
};
249
269template <typename T>
270struct Base {
271 using contained_type = T;
272
273 constexpr Base() = default;
274 constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
276
277 Base(const Base&) = delete;
278 Base& operator=(const Base&) = delete;
279
280 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
281 Base& operator=(Base&& v) noexcept {
282 OrtRelease(p_);
283 p_ = v.release();
284 return *this;
285 }
286
287 constexpr operator contained_type*() const noexcept { return p_; }
288
292 T* p = p_;
293 p_ = nullptr;
294 return p;
295 }
296
297 protected:
299};
300
301// Undefined. For const types use Base<Unowned<const T>>
302template <typename T>
303struct Base<const T>;
304
312template <typename T>
313struct Base<Unowned<T>> {
315
316 constexpr Base() = default;
317 constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
318
319 ~Base() = default;
320
321 Base(const Base&) = default;
322 Base& operator=(const Base&) = default;
323
324 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
325 Base& operator=(Base&& v) noexcept {
326 p_ = nullptr;
327 std::swap(p_, v.p_);
328 return *this;
329 }
330
331 constexpr operator contained_type*() const noexcept { return p_; }
332
333 protected:
335};
336
337// Light functor to release memory with OrtAllocator
340 explicit AllocatedFree(OrtAllocator* allocator)
341 : allocator_(allocator) {}
342 void operator()(void* ptr) const {
343 if (ptr) allocator_->Free(allocator_, ptr);
344 }
345};
346
347} // namespace detail
348
// Forward declarations for wrapper types referenced before their definitions.
struct AllocatorWithDefaultOptions;
struct Env;
struct TypeInfo;
struct Value;
struct ModelMetadata;
354
359using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
360
365struct Status : detail::Base<OrtStatus> {
366 explicit Status(std::nullptr_t) noexcept {}
367 explicit Status(OrtStatus* status) noexcept;
368 explicit Status(const Exception&) noexcept;
369 explicit Status(const std::exception&) noexcept;
370 Status(const char* message, OrtErrorCode code) noexcept;
371 std::string GetErrorMessage() const;
373 bool IsOK() const noexcept;
374};
375
383
386
389
392
395
398
400 ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
401
404};
405
411struct Env : detail::Base<OrtEnv> {
412 explicit Env(std::nullptr_t) {}
413
415 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
416
418 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
419
421 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
422
424 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
425 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
426
428 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
429
432
434
435 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
436};
437
441struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
442 explicit CustomOpDomain(std::nullptr_t) {}
443
445 explicit CustomOpDomain(const char* domain);
446
447 // This does not take ownership of the op, simply registers it.
448 void Add(const OrtCustomOp* op);
449};
450
454struct RunOptions : detail::Base<OrtRunOptions> {
455 explicit RunOptions(std::nullptr_t) {}
457
460
463
464 RunOptions& SetRunTag(const char* run_tag);
465 const char* GetRunTag() const;
466
467 RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
468
475
481};
482
namespace detail {
// Utility function that returns a SessionOptions config entry key for a specific custom operator.
// Ex: custom_op.[custom_op_name].[config]
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
}  // namespace detail
488
499 CustomOpConfigs() = default;
500 ~CustomOpConfigs() = default;
505
514 CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
515
524 const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
525
526 private:
527 std::unordered_map<std::string, std::string> flat_configs_;
528};
529
535struct SessionOptions;
536
537namespace detail {
538// we separate const-only methods because passing const ptr to non-const methods
539// is only discovered when inline methods are compiled which is counter-intuitive
540template <typename T>
542 using B = Base<T>;
543 using B::B;
544
546
547 std::string GetConfigEntry(const char* config_key) const;
548 bool HasConfigEntry(const char* config_key) const;
549 std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
550};
551
552template <typename T>
555 using B::B;
556
557 SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads);
558 SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads);
560
563
564 SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
565
566 SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
568
570
573
575
576 SessionOptionsImpl& SetLogId(const char* logid);
578
580
582
583 SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value);
584
585 SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val);
586 SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values);
587
600 SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
601 const std::unordered_map<std::string, std::string>& provider_options = {});
602
604 SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
606
610 SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
611
613};
614} // namespace detail
615
618
622struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
623 explicit SessionOptions(std::nullptr_t) {}
625 explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {}
628};
629
633struct ModelMetadata : detail::Base<OrtModelMetadata> {
634 explicit ModelMetadata(std::nullptr_t) {}
636
644
652
660
668
676
683 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const;
684
695
696 int64_t GetVersion() const;
697};
698
699struct IoBinding;
700
701namespace detail {
702
703// we separate const-only methods because passing const ptr to non-const methods
704// is only discovered when inline methods are compiled which is counter-intuitive
705template <typename T>
707 using B = Base<T>;
708 using B::B;
709
710 size_t GetInputCount() const;
711 size_t GetOutputCount() const;
713
722
731
740
741 uint64_t GetProfilingStartTimeNs() const;
743
744 TypeInfo GetInputTypeInfo(size_t index) const;
745 TypeInfo GetOutputTypeInfo(size_t index) const;
747};
748
749template <typename T>
752 using B::B;
753
771 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
772 const char* const* output_names, size_t output_count);
773
777 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
778 const char* const* output_names, Value* output_values, size_t output_count);
779
780 void Run(const RunOptions& run_options, const IoBinding&);
781
789};
790
791} // namespace detail
792
795
799struct Session : detail::SessionImpl<OrtSession> {
800 explicit Session(std::nullptr_t) {}
801 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
802 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
803 OrtPrepackedWeightsContainer* prepacked_weights_container);
804 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
805 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
806 OrtPrepackedWeightsContainer* prepacked_weights_container);
807
808 ConstSession GetConst() const { return ConstSession{this->p_}; }
809 UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
810};
811
812namespace detail {
813template <typename T>
814struct MemoryInfoImpl : Base<T> {
815 using B = Base<T>;
816 using B::B;
817
818 std::string GetAllocatorName() const;
820 int GetDeviceId() const;
823
824 template <typename U>
825 bool operator==(const MemoryInfoImpl<U>& o) const;
826};
827} // namespace detail
828
829// Const object holder that does not own the underlying object
831
835struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
837 explicit MemoryInfo(std::nullptr_t) {}
838 explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {}
839 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
840 ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
841};
842
843namespace detail {
844template <typename T>
846 using B = Base<T>;
847 using B::B;
848
850 size_t GetElementCount() const;
851
852 size_t GetDimensionsCount() const;
853
858 [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const;
859
860 void GetSymbolicDimensions(const char** values, size_t values_count) const;
861
862 std::vector<int64_t> GetShape() const;
863};
864
865} // namespace detail
866
868
873 explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
874 explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {}
876};
877
878namespace detail {
879template <typename T>
881 using B = Base<T>;
882 using B::B;
884};
885
886} // namespace detail
887
889
893struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
894 explicit SequenceTypeInfo(std::nullptr_t) {}
895 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {}
897};
898
899namespace detail {
900template <typename T>
902 using B = Base<T>;
903 using B::B;
905};
906
907} // namespace detail
908
909// This is always owned by the TypeInfo and can only be obtained from it.
911
912namespace detail {
913template <typename T>
915 using B = Base<T>;
916 using B::B;
919};
920
921} // namespace detail
922
924
928struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
929 explicit MapTypeInfo(std::nullptr_t) {}
930 explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {}
931 ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
932};
933
934namespace detail {
935template <typename T>
937 using B = Base<T>;
938 using B::B;
939
944
946};
947} // namespace detail
948
954
959struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
960 explicit TypeInfo(std::nullptr_t) {}
961 explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {}
962
963 ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
964};
965
966namespace detail {
967// This structure is used to feed sparse tensor values
968// information for use with FillSparseTensor<Format>() API
969// if the data type for the sparse tensor values is numeric
970// use data.p_data, otherwise, use data.str pointer to feed
971// values. data.str is an array of const char* that are zero terminated.
972// number of strings in the array must match shape size.
973// For fully sparse tensors use shape {0} and set p_data/str
974// to nullptr.
976 const int64_t* values_shape;
978 union {
979 const void* p_data;
980 const char** str;
981 } data;
982};
983
984// Provides a way to pass shape in a single
985// argument
/// Non-owning view of a shape, bundled so it can be passed as a single argument:
/// `shape` points to `shape_len` dimension values. The pointed-to array must outlive the Shape.
struct Shape {
  const int64_t* shape;
  size_t shape_len;
};
990
991template <typename T>
992struct ConstValueImpl : Base<T> {
993 using B = Base<T>;
994 using B::B;
995
999 template <typename R>
1000 void GetOpaqueData(const char* domain, const char* type_name, R&) const;
1001
1002 bool IsTensor() const;
1003 bool HasValue() const;
1004
1005 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
1006 Value GetValue(int index, OrtAllocator* allocator) const;
1007
1015
1030 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1031
1038 template <typename R>
1039 const R* GetTensorData() const;
1040
1045 const void* GetTensorRawData() const;
1046
1054
1062
1068
1077 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1078
1085 std::string GetStringTensorElement(size_t element_index) const;
1086
1093 size_t GetStringTensorElementLength(size_t element_index) const;
1094
1095#if !defined(DISABLE_SPARSE_TENSORS)
1103
1110
1119
1129 template <typename R>
1130 const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1131
1136 bool IsSparseTensor() const;
1137
1146 template <typename R>
1147 const R* GetSparseTensorValues() const;
1148
1149#endif
1150};
1151
1152template <typename T>
1155 using B::B;
1156
1162 template <typename R>
1164
1170
1172 // Obtain a reference to an element of data at the location specified
1178 template <typename R>
1179 R& At(const std::vector<int64_t>& location);
1180
1186 void FillStringTensor(const char* const* s, size_t s_len);
1187
1193 void FillStringTensorElement(const char* s, size_t index);
1194
1207 char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
1208
1209#if !defined(DISABLE_SPARSE_TENSORS)
1218 void UseCooIndices(int64_t* indices_data, size_t indices_num);
1219
1230 void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1231
1240 void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1241
1251 void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1252 const int64_t* indices_data, size_t indices_num);
1253
1265 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1266 const OrtSparseValuesParam& values,
1267 const int64_t* inner_indices_data, size_t inner_indices_num,
1268 const int64_t* outer_indices_data, size_t outer_indices_num);
1269
1280 const OrtSparseValuesParam& values,
1281 const Shape& indices_shape,
1282 const int32_t* indices_data);
1283
1284#endif
1285};
1286
1287} // namespace detail
1288
1291
1295struct Value : detail::ValueImpl<OrtValue> {
1299
1300 explicit Value(std::nullptr_t) {}
1301 explicit Value(OrtValue* p) : Base{p} {}
1302 Value(Value&&) = default;
1303 Value& operator=(Value&&) = default;
1304
1305 ConstValue GetConst() const { return ConstValue{this->p_}; }
1306 UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1307
1316 template <typename T>
1317 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1318
1327 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1329
1336 template <typename T>
1337 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1338
1345 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1346
1347 static Value CreateMap(Value& keys, Value& values);
1348 static Value CreateSequence(std::vector<Value>& values);
1349
1350 template <typename T>
1351 static Value CreateOpaque(const char* domain, const char* type_name, const T&);
1352
1353#if !defined(DISABLE_SPARSE_TENSORS)
1364 template <typename T>
1365 static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1366 const Shape& values_shape);
1367
1384 static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1385 const Shape& values_shape, ONNXTensorElementDataType type);
1386
1396 template <typename T>
1397 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1398
1410 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1411
1412#endif // !defined(DISABLE_SPARSE_TENSORS)
1413};
1414
1422 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1427 MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1428
1429 void* get() { return p_; }
1430 size_t size() const { return size_; }
1431
1432 private:
1433 OrtAllocator* allocator_;
1434 void* p_;
1435 size_t size_;
1436};
1437
1438namespace detail {
1439template <typename T>
1440struct AllocatorImpl : Base<T> {
1441 using B = Base<T>;
1442 using B::B;
1443
1444 void* Alloc(size_t size);
1446 void Free(void* p);
1448};
1449
1450} // namespace detail
1451
1455struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1456 explicit AllocatorWithDefaultOptions(std::nullptr_t) {}
1458};
1459
1463struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1464 explicit Allocator(std::nullptr_t) {}
1465 Allocator(const Session& session, const OrtMemoryInfo*);
1466};
1467
1469
1470namespace detail {
1471namespace binding_utils {
1472// Bring these out of template
1473std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1474std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1475} // namespace binding_utils
1476
1477template <typename T>
1479 using B = Base<T>;
1480 using B::B;
1481
1482 std::vector<std::string> GetOutputNames() const;
1483 std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1484 std::vector<Value> GetOutputValues() const;
1485 std::vector<Value> GetOutputValues(OrtAllocator*) const;
1486};
1487
1488template <typename T>
1491 using B::B;
1492
1493 void BindInput(const char* name, const Value&);
1494 void BindOutput(const char* name, const Value&);
1495 void BindOutput(const char* name, const OrtMemoryInfo*);
1500};
1501
1502} // namespace detail
1503
1506
1510struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1511 explicit IoBinding(std::nullptr_t) {}
1512 explicit IoBinding(Session& session);
1513 ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1514 UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1515};
1516
1521struct ArenaCfg : detail::Base<OrtArenaCfg> {
1522 explicit ArenaCfg(std::nullptr_t) {}
1531 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1532};
1533
1534//
1535// Custom OPs (only needed to implement custom OPs)
1536//
1537
1541struct OpAttr : detail::Base<OrtOpAttr> {
1542 OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1543};
1544
/// Logs `message` via logger.LogMessage when `message_severity` >= logger.GetLoggingSeverityLevel(),
/// passing the returned Status to Ort::ThrowOnError.
/// Arguments are parenthesized so expressions such as `*logger_ptr` or `a ? b : c` expand correctly.
/// Note: `logger` and `message_severity` are each evaluated twice.
#define ORT_CXX_LOG(logger, message_severity, message)                                           \
  do {                                                                                           \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                              \
      Ort::ThrowOnError((logger).LogMessage((message_severity), ORT_FILE, __LINE__,              \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    }                                                                                            \
  } while (false)
1560
/// Same as ORT_CXX_LOG, but the Status returned by logger.LogMessage is discarded instead of being passed
/// to Ort::ThrowOnError. Arguments are parenthesized so expressions such as `*logger_ptr` expand correctly.
/// Note: `logger` and `message_severity` are each evaluated twice.
#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message)                                   \
  do {                                                                                            \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                               \
      static_cast<void>((logger).LogMessage((message_severity), ORT_FILE, __LINE__,               \
                                            static_cast<const char*>(__FUNCTION__), (message)));  \
    }                                                                                             \
  } while (false)
1576
/// Formatted variant of ORT_CXX_LOG: the variadic arguments (a format string followed by its arguments) are
/// forwarded unchanged to logger.LogFormattedMessage, and the returned Status is passed to Ort::ThrowOnError.
/// `logger` and `message_severity` are parenthesized (and each evaluated twice).
#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...)                                                \
  do {                                                                                                         \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                                            \
      Ort::ThrowOnError((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,                   \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__));    \
    }                                                                                                          \
  } while (false)
1595
/// Formatted variant of ORT_CXX_LOG_NOEXCEPT: the variadic arguments are forwarded unchanged to
/// logger.LogFormattedMessage and the returned Status is discarded.
/// `logger` and `message_severity` are parenthesized (and each evaluated twice).
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...)                                       \
  do {                                                                                                         \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                                            \
      static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,                   \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__));    \
    }                                                                                                          \
  } while (false)
1614
1625struct Logger {
1629 Logger() = default;
1630
1634 explicit Logger(std::nullptr_t) {}
1635
1642 explicit Logger(const OrtLogger* logger);
1643
1644 ~Logger() = default;
1645
1646 Logger(const Logger&) = default;
1647 Logger& operator=(const Logger&) = default;
1648
1649 Logger(Logger&& v) noexcept = default;
1650 Logger& operator=(Logger&& v) noexcept = default;
1651
1658
1671 Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1672 const char* func_name, const char* message) const noexcept;
1673
1688 template <typename... Args>
1689 Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1690 const char* func_name, const char* format, Args&&... args) const noexcept;
1691
1692 private:
1693 const OrtLogger* logger_{};
1694 OrtLoggingLevel cached_severity_level_{};
1695};
1696
1705 size_t GetInputCount() const;
1706 size_t GetOutputCount() const;
1707 ConstValue GetInput(size_t index) const;
1708 UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
1709 UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
1710 void* GetGPUComputeStream() const;
1712 OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
1713
1714 private:
1715 OrtKernelContext* ctx_;
1716};
1717
1718struct KernelInfo;
1719
1720namespace detail {
1721namespace attr_utils {
1722void GetAttr(const OrtKernelInfo* p, const char* name, float&);
1723void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
1724void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
1725void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
1726void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
1727} // namespace attr_utils
1728
1729template <typename T>
1731 using B = Base<T>;
1732 using B::B;
1733
1735
1736 template <typename R> // R is only implemented for float, int64_t, and string
1737 R GetAttribute(const char* name) const {
1738 R val;
1739 attr_utils::GetAttr(this->p_, name, val);
1740 return val;
1741 }
1742
1743 template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
1744 std::vector<R> GetAttributes(const char* name) const {
1745 std::vector<R> result;
1746 attr_utils::GetAttrs(this->p_, name, result);
1747 return result;
1748 }
1749
1750 Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
1751
1752 size_t GetInputCount() const;
1753 size_t GetOutputCount() const;
1754
1755 std::string GetInputName(size_t index) const;
1756 std::string GetOutputName(size_t index) const;
1757
1758 TypeInfo GetInputTypeInfo(size_t index) const;
1759 TypeInfo GetOutputTypeInfo(size_t index) const;
1760
1761 ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
1762
1763 std::string GetNodeName() const;
1765};
1766
1767} // namespace detail
1768
1770
1777struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
1778 explicit KernelInfo(std::nullptr_t) {}
1779 explicit KernelInfo(OrtKernelInfo* info);
1780 ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
1781};
1782
1786struct Op : detail::Base<OrtOp> {
1787 explicit Op(std::nullptr_t) {}
1788
1789 explicit Op(OrtOp*);
1790
1791 static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
1792 int version, const char** type_constraint_names,
1793 const ONNXTensorElementDataType* type_constraint_values,
1794 size_t type_constraint_count,
1795 const OpAttr* attr_values,
1796 size_t attr_count,
1797 size_t input_count, size_t output_count);
1798
1799 void Invoke(const OrtKernelContext* context,
1800 const Value* input_values,
1801 size_t input_count,
1802 Value* output_values,
1803 size_t output_count);
1804
1805 // For easier refactoring
1806 void Invoke(const OrtKernelContext* context,
1807 const OrtValue* const* input_values,
1808 size_t input_count,
1809 OrtValue* const* output_values,
1810 size_t output_count);
1811};
1812
1818 CustomOpApi(const OrtApi& api) : api_(api) {}
1819
1824 [[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
1825
1830 [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1831
1836 [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
1837
1842 [[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1843
1848 [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
1849
1854 [[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
1855
1860 template <typename T>
1861 [[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value);
1862
1867 template <typename T>
1868 [[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value);
1869
1874 [[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
1875
1880 [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
1881
1886 [[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
1887
1892 [[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context);
1893
1898 [[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
1899
1904 [[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
1905
1910 [[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
1911
1916 [[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context);
1917
1922 [[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result);
1923
1928 [[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name,
1929 _In_ const void* data,
1930 _In_ int len,
1931 _In_ OrtOpAttrType type);
1932
1937 [[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
1938
1943 [[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
1944 _In_z_ const char* op_name,
1945 _In_z_ const char* domain,
1946 int version,
1947 _In_reads_(type_constraint_count) const char** type_constraint_names,
1948 _In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values,
1949 int type_constraint_count,
1950 _In_reads_(attr_count) const OrtOpAttr* const* attr_values,
1951 int attr_count,
1952 int input_count,
1953 int output_count);
1954
1959 [[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context,
1960 _In_ const OrtOp* ort_op,
1961 _In_ const OrtValue* const* input_values,
1962 _In_ int input_count,
1963 _Inout_ OrtValue* const* output_values,
1964 _In_ int output_count);
1965
1970 [[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
1971
1977 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
1978 [[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
1979
1985 [[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info);
1986
1992 [[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
1993
1994 private:
1995 const OrtApi& api_;
1996};
1997
// Helper base for implementing a custom operator on top of OrtCustomOp.
// The constructor fills the OrtCustomOp function-pointer table with lambdas that cast the
// incoming `const OrtCustomOp*` to `const TOp*` and forward to TOp's member functions, and
// that cast opaque kernel pointers to `TKernel*` for Compute() and destruction.
// NOTE(review): the struct header and constructor opening are not shown in this listing;
// TOp presumably derives from this base (CRTP) so the static_casts are valid - confirm upstream.
1998template <typename TOp, typename TKernel>
2002 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2003 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
2004
2005 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2006
2007 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2008 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2009 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2010
2011 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2012 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2013
     // Kernel objects are passed around as void*; recover TKernel to run Compute().
2014 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
     // Kernels are released with an explicit delete, so TOp::CreateKernel presumably allocates
     // them with new - confirm. The MSVC-only pragmas locally suppress warning 26409
     // (presumably the Core Guidelines "avoid explicit new/delete" check) for this one line.
2015#if defined(_MSC_VER) && !defined(__clang__)
2016#pragma warning(push)
2017#pragma warning(disable : 26409)
2018#endif
2019 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
2020#if defined(_MSC_VER) && !defined(__clang__)
2021#pragma warning(pop)
2022#endif
2023 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2024 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2025
     // The homogeneity callbacks convert TOp's bool result to int for the C function table.
2026 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2027 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
2028 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2029 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
2030 }
2031
2032 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2033 const char* GetExecutionProviderType() const { return nullptr; }
2034
2035 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2036 // (inputs and outputs are required by default)
2038 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2039 }
2040
2042 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2043 }
2044
2045 // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
2046 OrtMemType GetInputMemoryType(size_t /*index*/) const {
2047 return OrtMemTypeDefault;
2048 }
2049
2050 // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
2051 // should expect at least 1 argument.
2053 return 1;
2054 }
2055
2056 // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
2057 // to a variadic input should be of the same type.
2059 return true;
2060 }
2061
2062 // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
2063 // should produce at least 1 output value.
2065 return 1;
2066 }
2067
2068 // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
2069 // produced by a variadic output should be of the same type.
2071 return true;
2072 }
2073
2074 // Declare list of session config entries used by this Custom Op.
2075 // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
2076 // This default implementation returns an empty vector of config entries.
2077 std::vector<std::string> GetSessionConfigKeys() const {
2078 return std::vector<std::string>{};
2079 }
2080
2081 protected:
2082 // Helper function that fills `out` with the session config entries (from `options`) whose keys are listed by GetSessionConfigKeys().
2083 void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2084};
2085
2086} // namespace Ort
2087
2088#include "onnxruntime_cxx_inline.h"
struct OrtMemoryInfo OrtMemoryInfo
Definition: onnxruntime_c_api.h:271
struct OrtKernelInfo OrtKernelInfo
Definition: onnxruntime_c_api.h:350
OrtLoggingLevel
Logging severity levels.
Definition: onnxruntime_c_api.h:226
OrtMemoryInfoDeviceType
This mimics OrtDevice type constants so they can be returned in the API.
Definition: onnxruntime_c_api.h:374
void(* OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
Definition: onnxruntime_c_api.h:315
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
Definition: onnxruntime_c_api.h:676
OrtCustomOpInputOutputCharacteristic
Definition: onnxruntime_c_api.h:4297
struct OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsV2
Definition: onnxruntime_c_api.h:288
struct OrtOpAttr OrtOpAttr
Definition: onnxruntime_c_api.h:293
struct OrtThreadingOptions OrtThreadingOptions
Definition: onnxruntime_c_api.h:285
struct OrtSequenceTypeInfo OrtSequenceTypeInfo
Definition: onnxruntime_c_api.h:279
struct OrtDnnlProviderOptions OrtDnnlProviderOptions
Definition: onnxruntime_c_api.h:291
OrtSparseIndicesFormat
Definition: onnxruntime_c_api.h:215
struct OrtPrepackedWeightsContainer OrtPrepackedWeightsContainer
Definition: onnxruntime_c_api.h:287
struct OrtCustomOpDomain OrtCustomOpDomain
Definition: onnxruntime_c_api.h:282
struct OrtIoBinding OrtIoBinding
Definition: onnxruntime_c_api.h:272
OrtAllocatorType
Definition: onnxruntime_c_api.h:356
struct OrtOp OrtOp
Definition: onnxruntime_c_api.h:292
struct OrtModelMetadata OrtModelMetadata
Definition: onnxruntime_c_api.h:283
struct OrtTypeInfo OrtTypeInfo
Definition: onnxruntime_c_api.h:276
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition: onnxruntime_c_api.h:277
struct OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2
Definition: onnxruntime_c_api.h:289
struct OrtKernelContext OrtKernelContext
Definition: onnxruntime_c_api.h:352
struct OrtCANNProviderOptions OrtCANNProviderOptions
Definition: onnxruntime_c_api.h:290
struct OrtSessionOptions OrtSessionOptions
Definition: onnxruntime_c_api.h:281
struct OrtValue OrtValue
Definition: onnxruntime_c_api.h:274
GraphOptimizationLevel
Graph optimization level.
Definition: onnxruntime_c_api.h:324
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
Definition: onnxruntime_c_api.h:365
OrtSparseFormat
Definition: onnxruntime_c_api.h:207
ONNXType
Definition: onnxruntime_c_api.h:195
struct OrtEnv OrtEnv
Definition: onnxruntime_c_api.h:269
OrtErrorCode
Definition: onnxruntime_c_api.h:234
struct OrtStatus OrtStatus
Definition: onnxruntime_c_api.h:270
#define ORT_API_VERSION
The API version defined in this header.
Definition: onnxruntime_c_api.h:40
struct OrtLogger OrtLogger
Definition: onnxruntime_c_api.h:294
struct OrtMapTypeInfo OrtMapTypeInfo
Definition: onnxruntime_c_api.h:278
struct OrtArenaCfg OrtArenaCfg
Definition: onnxruntime_c_api.h:286
ExecutionMode
Definition: onnxruntime_c_api.h:331
OrtOpAttrType
Definition: onnxruntime_c_api.h:249
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
Definition: onnxruntime_c_api.h:669
ONNXTensorElementDataType
Definition: onnxruntime_c_api.h:174
const OrtApiBase * OrtGetApiBase(void)
The Onnxruntime library's entry point to access the C API.
@ ORT_LOGGING_LEVEL_WARNING
Warning messages.
Definition: onnxruntime_c_api.h:229
@ OrtMemTypeDefault
The default allocator for execution provider.
Definition: onnxruntime_c_api.h:369
void GetAttr(const OrtKernelInfo *p, const char *name, float &)
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< float > &)
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
void OrtRelease(OrtAllocator *ptr)
Definition: onnxruntime_cxx_api.h:219
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
All C++ Onnxruntime APIs are defined inside this namespace.
Definition: onnxruntime_cxx_api.h:45
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Definition: onnxruntime_cxx_api.h:359
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition: onnxruntime_cxx_api.h:122
std::string GetBuildInfoString()
This function returns the onnxruntime build information, including git branch, git commit id,...
std::string GetVersionString()
This function returns the onnxruntime version string.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Wrapper around OrtAllocator.
Definition: onnxruntime_cxx_api.h:1463
Allocator(const Session &session, const OrtMemoryInfo *)
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition: onnxruntime_cxx_api.h:1464
Wrapper around OrtAllocator default instance that is owned by Onnxruntime.
Definition: onnxruntime_cxx_api.h:1455
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition: onnxruntime_cxx_api.h:1456
It is a structure that represents the configuration of an arena-based allocator.
Definition: onnxruntime_cxx_api.h:1521
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:1522
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk)
bfloat16 (Brain Floating Point) data type
Definition: onnxruntime_cxx_api.h:202
uint16_t value
Definition: onnxruntime_cxx_api.h:203
constexpr bool operator!=(const BFloat16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:208
constexpr BFloat16_t(uint16_t v) noexcept
Definition: onnxruntime_cxx_api.h:205
constexpr bool operator==(const BFloat16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:207
constexpr BFloat16_t() noexcept
Definition: onnxruntime_cxx_api.h:204
This entire structure is deprecated, but we are not marking it as a whole yet since we want to preserve f...
Definition: onnxruntime_cxx_api.h:1817
size_t KernelContext_GetOutputCount(const OrtKernelContext *context)
size_t GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info)
void * KernelContext_GetGPUComputeStream(const OrtKernelContext *context)
size_t KernelContext_GetInputCount(const OrtKernelContext *context)
void InvokeOp(const OrtKernelContext *context, const OrtOp *ort_op, const OrtValue *const *input_values, int input_count, OrtValue *const *output_values, int output_count)
OrtOpAttr * CreateOpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
void ReleaseOp(OrtOp *ort_op)
OrtValue * KernelContext_GetOutput(OrtKernelContext *context, size_t index, const int64_t *dim_values, size_t dim_count)
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input)
T KernelInfoGetAttribute(const OrtKernelInfo *info, const char *name)
OrtTensorTypeAndShapeInfo * GetTensorTypeAndShape(const OrtValue *value)
std::vector< int64_t > GetTensorShape(const OrtTensorTypeAndShapeInfo *info)
void GetDimensions(const OrtTensorTypeAndShapeInfo *info, int64_t *dim_values, size_t dim_values_length)
void ReleaseOpAttr(OrtOpAttr *op_attr)
void ThrowOnError(OrtStatus *result)
size_t GetTensorShapeElementCount(const OrtTensorTypeAndShapeInfo *info)
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo *info)
OrtOp * CreateOp(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, int type_constraint_count, const OrtOpAttr *const *attr_values, int attr_count, int input_count, int output_count)
CustomOpApi(const OrtApi &api)
Definition: onnxruntime_cxx_api.h:1818
void SetDimensions(OrtTensorTypeAndShapeInfo *info, const int64_t *dim_values, size_t dim_count)
OrtKernelInfo * CopyKernelInfo(const OrtKernelInfo *info)
void ReleaseKernelInfo(OrtKernelInfo *info_copy)
T * GetTensorMutableData(OrtValue *value)
const OrtValue * KernelContext_GetInput(const OrtKernelContext *context, size_t index)
const OrtMemoryInfo * GetTensorMemoryInfo(const OrtValue *value)
const T * GetTensorData(const OrtValue *value)
Definition: onnxruntime_cxx_api.h:1999
std::vector< std::string > GetSessionConfigKeys() const
Definition: onnxruntime_cxx_api.h:2077
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
Definition: onnxruntime_cxx_api.h:2041
bool GetVariadicInputHomogeneity() const
Definition: onnxruntime_cxx_api.h:2058
CustomOpBase()
Definition: onnxruntime_cxx_api.h:2000
bool GetVariadicOutputHomogeneity() const
Definition: onnxruntime_cxx_api.h:2070
OrtMemType GetInputMemoryType(size_t) const
Definition: onnxruntime_cxx_api.h:2046
int GetVariadicInputMinArity() const
Definition: onnxruntime_cxx_api.h:2052
const char * GetExecutionProviderType() const
Definition: onnxruntime_cxx_api.h:2033
int GetVariadicOutputMinArity() const
Definition: onnxruntime_cxx_api.h:2064
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
Definition: onnxruntime_cxx_api.h:2037
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
Class that represents session configuration entries for one or more custom operators.
Definition: onnxruntime_cxx_api.h:498
~CustomOpConfigs()=default
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
CustomOpConfigs & operator=(CustomOpConfigs &&o)=default
CustomOpConfigs(CustomOpConfigs &&o)=default
CustomOpConfigs()=default
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
CustomOpConfigs(const CustomOpConfigs &)=default
CustomOpConfigs & operator=(const CustomOpConfigs &)=default
Custom Op Domain.
Definition: onnxruntime_cxx_api.h:441
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:442
CustomOpDomain(const char *domain)
Wraps OrtApi::CreateCustomOpDomain.
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
The Env (Environment)
Definition: onnxruntime_cxx_api.h:411
Env & EnableTelemetryEvents()
Wraps OrtApi::EnableTelemetryEvents.
Env(OrtEnv *p)
C Interop Helper.
Definition: onnxruntime_cxx_api.h:428
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:412
Env(OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnv.
Env(const OrtThreadingOptions *tp_options, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithGlobalThreadPools.
Env(const OrtThreadingOptions *tp_options, OrtLoggingFunction logging_function, void *logger_param, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools.
Env(OrtLoggingLevel logging_level, const char *logid, OrtLoggingFunction logging_function, void *logger_param)
Wraps OrtApi::CreateEnvWithCustomLogger.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
All C++ methods that can fail will throw an exception of this type.
Definition: onnxruntime_cxx_api.h:51
const char * what() const noexcept override
Definition: onnxruntime_cxx_api.h:55
OrtErrorCode GetOrtErrorCode() const
Definition: onnxruntime_cxx_api.h:54
Exception(std::string &&string, OrtErrorCode code)
Definition: onnxruntime_cxx_api.h:52
IEEE 754 half-precision floating point data type.
Definition: onnxruntime_cxx_api.h:183
constexpr bool operator!=(const Float16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:189
constexpr Float16_t(uint16_t v) noexcept
Definition: onnxruntime_cxx_api.h:186
uint16_t value
Definition: onnxruntime_cxx_api.h:184
constexpr bool operator==(const Float16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:188
constexpr Float16_t() noexcept
Definition: onnxruntime_cxx_api.h:185
Definition: onnxruntime_cxx_api.h:83
static const OrtApi * api_
Definition: onnxruntime_cxx_api.h:84
Wrapper around OrtIoBinding.
Definition: onnxruntime_cxx_api.h:1510
UnownedIoBinding GetUnowned() const
Definition: onnxruntime_cxx_api.h:1514
ConstIoBinding GetConst() const
Definition: onnxruntime_cxx_api.h:1513
IoBinding(Session &session)
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
Definition: onnxruntime_cxx_api.h:1511
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
Definition: onnxruntime_cxx_api.h:1703
KernelContext(OrtKernelContext *context)
Logger GetLogger() const
ConstValue GetInput(size_t index) const
OrtAllocator * GetAllocator(const OrtMemoryInfo &memory_info) const
void * GetGPUComputeStream() const
size_t GetInputCount() const
size_t GetOutputCount() const
UnownedValue GetOutput(size_t index, const std::vector< int64_t > &dims) const
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
Definition: onnxruntime_cxx_api.h:1777
KernelInfo(OrtKernelInfo *info)
Take ownership of the instance.
ConstKernelInfo GetConst() const
Definition: onnxruntime_cxx_api.h:1780
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
Definition: onnxruntime_cxx_api.h:1778
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Definition: onnxruntime_cxx_api.h:1625
Logger(Logger &&v) noexcept=default
Logger & operator=(Logger &&v) noexcept=default
Logger & operator=(const Logger &)=default
~Logger()=default
Logger(const Logger &)=default
Logger()=default
Logger(std::nullptr_t)
Definition: onnxruntime_cxx_api.h:1634
Logger(const OrtLogger *logger)
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
Wrapper around OrtMapTypeInfo.
Definition: onnxruntime_cxx_api.h:928
ConstMapTypeInfo GetConst() const
Definition: onnxruntime_cxx_api.h:931
MapTypeInfo(OrtMapTypeInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:930
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:929
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime....
Definition: onnxruntime_cxx_api.h:1421
MemoryAllocation(MemoryAllocation &&) noexcept
MemoryAllocation & operator=(const MemoryAllocation &)=delete
MemoryAllocation(const MemoryAllocation &)=delete
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
size_t size() const
Definition: onnxruntime_cxx_api.h:1430
Wrapper around OrtMemoryInfo.
Definition: onnxruntime_cxx_api.h:835
MemoryInfo(const char *name, OrtAllocatorType type, int id, OrtMemType mem_type)
MemoryInfo(std::nullptr_t)
No instance is created.
Definition: onnxruntime_cxx_api.h:837
MemoryInfo(OrtMemoryInfo *p)
Take ownership of a pointer created by C Api.
Definition: onnxruntime_cxx_api.h:838
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
ConstMemoryInfo GetConst() const
Definition: onnxruntime_cxx_api.h:840
Wrapper around OrtModelMetadata.
Definition: onnxruntime_cxx_api.h:633
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:634
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator *allocator) const
Returns a copy of the producer name.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Returns a copy of the graph name.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
ModelMetadata(OrtModelMetadata *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:635
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
This struct provides life time management for custom op attribute.
Definition: onnxruntime_cxx_api.h:1541
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
Create and own custom defined operation.
Definition: onnxruntime_cxx_api.h:1786
Op(OrtOp *)
Take ownership of the OrtOp.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:1787
void Invoke(const OrtKernelContext *context, const OrtValue *const *input_values, size_t input_count, OrtValue *const *output_values, size_t output_count)
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
RunOptions.
Definition: onnxruntime_cxx_api.h:454
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance.
RunOptions & SetRunTag(const char *run_tag)
wraps OrtApi::RunOptionsSetRunTag
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:455
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
RunOptions()
Wraps OrtApi::CreateRunOptions.
Wrapper around OrtSequenceTypeInfo.
Definition: onnxruntime_cxx_api.h:893
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:894
ConstSequenceTypeInfo GetConst() const
Definition: onnxruntime_cxx_api.h:896
SequenceTypeInfo(OrtSequenceTypeInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:895
Wrapper around OrtSession.
Definition: onnxruntime_cxx_api.h:799
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:800
UnownedSession GetUnowned() const
Definition: onnxruntime_cxx_api.h:809
Session(const Env &env, const char *model_path, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer.
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer.
Session(const Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtApi::CreateSession.
ConstSession GetConst() const
Definition: onnxruntime_cxx_api.h:808
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtApi::CreateSessionFromArray.
Wrapper around OrtSessionOptions.
Definition: onnxruntime_cxx_api.h:622
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:623
UnownedSessionOptions GetUnowned() const
Definition: onnxruntime_cxx_api.h:626
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
ConstSessionOptions GetConst() const
Definition: onnxruntime_cxx_api.h:627
SessionOptions(OrtSessionOptions *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:625
The Status that holds ownership of OrtStatus received from C API Use it to safely destroy OrtStatus* ...
Definition: onnxruntime_cxx_api.h:365
OrtErrorCode GetErrorCode() const
Status(const char *message, OrtErrorCode code) noexcept
Creates status instance out of null-terminated string message.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
Status(OrtStatus *status) noexcept
Takes ownership of OrtStatus instance returned from the C API.
std::string GetErrorMessage() const
Status(const Exception &) noexcept
Creates status instance out of exception.
Status(const std::exception &) noexcept
Creates status instance out of exception.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:366
Wrapper around OrtTensorTypeAndShapeInfo.
Definition: onnxruntime_cxx_api.h:872
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:873
ConstTensorTypeAndShapeInfo GetConst() const
Definition: onnxruntime_cxx_api.h:875
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:874
The ThreadingOptions.
Definition: onnxruntime_cxx_api.h:380
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
Definition: onnxruntime_cxx_api.h:959
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:960
ConstTypeInfo GetConst() const
Definition: onnxruntime_cxx_api.h:963
TypeInfo(OrtTypeInfo *p)
C API Interop.
Definition: onnxruntime_cxx_api.h:961
Wrapper around OrtValue.
Definition: onnxruntime_cxx_api.h:1295
static Value CreateMap(Value &keys, Value &values)
Wraps OrtApi::CreateValue.
static Value CreateSparseTensor(const OrtMemoryInfo *info, void *p_data, const Shape &dense_shape, const Shape &values_shape, ONNXTensorElementDataType type)
Creates an OrtValue instance containing SparseTensor. This constructs a sparse tensor that makes use ...
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deduce the data type enum value fro...
Value & operator=(Value &&)=default
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape, ONNXTensorElementDataType type)
Creates an instance of OrtValue containing sparse tensor. The created instance has no data....
Value(Value &&)=default
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:1300
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
Value(OrtValue *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:1301
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape)
This is a simple forwarding method to the below CreateSparseTensor. This helps to specify data type e...
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
UnownedValue GetUnowned() const
Definition: onnxruntime_cxx_api.h:1306
static Value CreateTensor(const OrtMemoryInfo *info, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateOpaque(const char *domain, const char *type_name, const T &)
Wraps OrtApi::CreateOpaqueValue.
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len)
Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
static Value CreateSequence(std::vector< Value > &values)
Wraps OrtApi::CreateValue.
ConstValue GetConst() const
Definition: onnxruntime_cxx_api.h:1305
Definition: onnxruntime_cxx_api.h:338
AllocatedFree(OrtAllocator *allocator)
Definition: onnxruntime_cxx_api.h:340
OrtAllocator * allocator_
Definition: onnxruntime_cxx_api.h:339
void operator()(void *ptr) const
Definition: onnxruntime_cxx_api.h:342
Definition: onnxruntime_cxx_api.h:1440
ConstMemoryInfo GetInfo() const
void * Alloc(size_t size)
MemoryAllocation GetAllocation(size_t size)
Base & operator=(Base &&v) noexcept
Definition: onnxruntime_cxx_api.h:325
typename Unowned< T >::Type contained_type
Definition: onnxruntime_cxx_api.h:314
Base(Base &&v) noexcept
Definition: onnxruntime_cxx_api.h:324
Base(const Base &)=default
constexpr Base(contained_type *p) noexcept
Definition: onnxruntime_cxx_api.h:317
Base & operator=(const Base &)=default
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition: onnxruntime_cxx_api.h:270
Base(Base &&v) noexcept
Definition: onnxruntime_cxx_api.h:280
constexpr Base()=default
contained_type * release()
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed.
Definition: onnxruntime_cxx_api.h:291
Base(const Base &)=delete
constexpr Base(contained_type *p) noexcept
Definition: onnxruntime_cxx_api.h:274
Base & operator=(const Base &)=delete
Base & operator=(Base &&v) noexcept
Definition: onnxruntime_cxx_api.h:281
contained_type * p_
Definition: onnxruntime_cxx_api.h:298
~Base()
Definition: onnxruntime_cxx_api.h:275
T contained_type
Definition: onnxruntime_cxx_api.h:271
Definition: onnxruntime_cxx_api.h:1478
std::vector< Value > GetOutputValues(OrtAllocator *) const
std::vector< std::string > GetOutputNames(OrtAllocator *) const
std::vector< Value > GetOutputValues() const
std::vector< std::string > GetOutputNames() const
Definition: onnxruntime_cxx_api.h:706
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
Returns the number of model outputs.
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
size_t GetInputCount() const
Returns the number of model inputs.
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the input name at the specified index.
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
Definition: onnxruntime_cxx_api.h:541
std::string GetConfigEntry(const char *config_key) const
Wraps OrtApi::GetSessionConfigEntry.
std::string GetConfigEntryOrDefault(const char *config_key, const std::string &def)
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
bool HasConfigEntry(const char *config_key) const
Wraps OrtApi::HasSessionConfigEntry.
Definition: onnxruntime_cxx_api.h:992
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
std::string GetStringTensorElement(size_t element_index) const
Returns string tensor UTF-8 encoded string element. Use of this API is recommended over GetStringTens...
size_t GetStringTensorElementLength(size_t element_index) const
The API returns the byte length of a UTF-8 encoded string element contained in either a tensor or a sparse...
size_t GetStringTensorDataLength() const
This API returns the full length of string data contained within either a tensor or a sparse tensor....
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor.
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values....
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
Value GetValue(int index, OrtAllocator *allocator) const
size_t GetCount() const
For a non-tensor OrtValue, returns the number of contained elements: 2 for a map and N for a sequence of N elements.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor....
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed,...
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
Definition: onnxruntime_cxx_api.h:1489
void BindOutput(const char *name, const Value &)
void BindInput(const char *name, const Value &)
void BindOutput(const char *name, const OrtMemoryInfo *)
Definition: onnxruntime_cxx_api.h:1730
Value GetTensorAttribute(const char *name, OrtAllocator *allocator) const
TypeInfo GetInputTypeInfo(size_t index) const
std::vector< R > GetAttributes(const char *name) const
Definition: onnxruntime_cxx_api.h:1744
R GetAttribute(const char *name) const
Definition: onnxruntime_cxx_api.h:1737
TypeInfo GetOutputTypeInfo(size_t index) const
KernelInfo Copy() const
std::string GetNodeName() const
std::string GetInputName(size_t index) const
size_t GetOutputCount() const
size_t GetInputCount() const
ConstValue GetTensorConstantInput(size_t index, int *is_constant) const
std::string GetOutputName(size_t index) const
Definition: onnxruntime_cxx_api.h:914
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
Definition: onnxruntime_cxx_api.h:814
std::string GetAllocatorName() const
OrtMemType GetMemoryType() const
OrtMemoryInfoDeviceType GetDeviceType() const
OrtAllocatorType GetAllocatorType() const
bool operator==(const MemoryInfoImpl< U > &o) const
Definition: onnxruntime_cxx_api.h:901
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
Definition: onnxruntime_cxx_api.h:975
const char ** str
Definition: onnxruntime_cxx_api.h:980
const int64_t * values_shape
Definition: onnxruntime_cxx_api.h:976
size_t values_shape_len
Definition: onnxruntime_cxx_api.h:977
const void * p_data
Definition: onnxruntime_cxx_api.h:979
Definition: onnxruntime_cxx_api.h:880
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
Definition: onnxruntime_cxx_api.h:750
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
void Run(const RunOptions &run_options, const IoBinding &)
Wraps OrtApi::RunWithBinding.
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
void Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count)
Run the model, returning results in user-provided outputs. Same as Run(const RunOptions&,...
Definition: onnxruntime_cxx_api.h:553
SessionOptionsImpl & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptionsImpl & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
SessionOptionsImpl & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
SessionOptionsImpl & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
SessionOptionsImpl & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
SessionOptionsImpl & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
SessionOptionsImpl & AppendExecutionProvider_CANN(const OrtCANNProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN.
SessionOptionsImpl & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
SessionOptionsImpl & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX.
SessionOptionsImpl & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
SessionOptionsImpl & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
SessionOptionsImpl & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
SessionOptionsImpl & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
SessionOptionsImpl & AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl.
SessionOptionsImpl & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
SessionOptionsImpl & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
SessionOptionsImpl & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
SessionOptionsImpl & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
SessionOptionsImpl & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
SessionOptionsImpl & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
SessionOptionsImpl & RegisterCustomOpsLibrary(const char *library_name, const CustomOpConfigs &custom_op_configs={})
SessionOptionsImpl & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2.
SessionOptionsImpl & RegisterCustomOpsUsingFunction(const char *function_name)
Wraps OrtApi::RegisterCustomOpsUsingFunction.
SessionOptionsImpl & DisableProfiling()
Wraps OrtApi::DisableProfiling.
SessionOptionsImpl & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
SessionOptionsImpl & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
SessionOptionsImpl & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
SessionOptionsImpl & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
SessionOptionsImpl & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
SessionOptionsImpl & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
SessionOptionsImpl & EnableProfiling(const char *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
SessionOptionsImpl & SetOptimizedModelFilePath(const char *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
SessionOptionsImpl & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
SessionOptionsImpl & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
Definition: onnxruntime_cxx_api.h:986
const int64_t * shape
Definition: onnxruntime_cxx_api.h:987
size_t shape_len
Definition: onnxruntime_cxx_api.h:988
Definition: onnxruntime_cxx_api.h:845
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
Definition: onnxruntime_cxx_api.h:936
ONNXType GetONNXType() const
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
ConstOptionalTypeInfo GetOptionalTypeInfo() const
Wraps OrtApi::CastTypeInfoToOptionalTypeInfo.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
Definition: onnxruntime_cxx_api.h:246
T Type
Definition: onnxruntime_cxx_api.h:247
Definition: onnxruntime_cxx_api.h:1153
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor.
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. No type checking is performed...
R & At(const std::vector< int64_t > &location)
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
Memory allocation interface.
Definition: onnxruntime_c_api.h:308
void(* Free)(struct OrtAllocator *this_, void *p)
Free a block of memory previously allocated with OrtAllocator::Alloc.
Definition: onnxruntime_c_api.h:311
const OrtApi *(* GetApi)(uint32_t version)
Get a pointer to the requested version of the OrtApi.
Definition: onnxruntime_c_api.h:636
The C API.
Definition: onnxruntime_c_api.h:687
CUDA Provider Options.
Definition: onnxruntime_c_api.h:392
Definition: onnxruntime_c_api.h:4307
int(* GetVariadicInputHomogeneity)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4348
OrtCustomOpInputOutputCharacteristic(* GetOutputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:4332
size_t(* GetInputTypeCount)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4322
int(* GetVariadicOutputMinArity)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4352
const char *(* GetName)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4315
size_t(* GetOutputTypeCount)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4324
void(* KernelDestroy)(void *op_kernel)
Definition: onnxruntime_c_api.h:4328
int(* GetVariadicOutputHomogeneity)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4357
OrtMemType(* GetInputMemoryType)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:4339
void *(* CreateKernel)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info)
Definition: onnxruntime_c_api.h:4311
uint32_t version
Definition: onnxruntime_c_api.h:4308
ONNXTensorElementDataType(* GetInputType)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:4321
OrtCustomOpInputOutputCharacteristic(* GetInputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:4331
const char *(* GetExecutionProviderType)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4318
ONNXTensorElementDataType(* GetOutputType)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:4323
int(* GetVariadicInputMinArity)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:4343
void(* KernelCompute)(void *op_kernel, OrtKernelContext *context)
Definition: onnxruntime_c_api.h:4327
MIGraphX Provider Options.
Definition: onnxruntime_c_api.h:581
OpenVINO Provider Options.
Definition: onnxruntime_c_api.h:591
ROCM Provider Options.
Definition: onnxruntime_c_api.h:473
TensorRT Provider Options.
Definition: onnxruntime_c_api.h:553