ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5//
6// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8// all the resources follow RAII and do not leak memory.
9//
10// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12// until you assign an instance that actually holds an underlying object.
13//
14// For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15// Some objects have explicit 'Clone' methods for this purpose.
16//
17// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18// by value or by reference. ConstXXXX types are restricted to const only interfaces.
19//
20// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21//
// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
// have to fall back to C types and the C API with the usual pitfalls. In general, do not use the C API from your C++ code.
24
25#pragma once
#include "onnxruntime_c_api.h"
#include "onnxruntime_float16.h"

#include <array>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
44
48namespace Ort {
49
54struct Exception : std::exception {
55 Exception(const std::string& string, OrtErrorCode code) : message_{string}, code_{code} {}
56 Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
57
58 OrtErrorCode GetOrtErrorCode() const { return code_; }
59 const char* what() const noexcept override { return message_.c_str(); }
60
61 private:
62 std::string message_;
63 OrtErrorCode code_;
64};
65
#ifdef ORT_NO_EXCEPTIONS
// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
#ifndef ORT_CXX_API_THROW
// Default handler when exceptions are disabled: write the message to std::cerr
// and terminate. std::endl flushes the stream so the text is emitted before
// std::abort() (declared in <cstdlib>) ends the process.
#define ORT_CXX_API_THROW(string, code)       \
  do {                                        \
    std::cerr << Ort::Exception(string, code) \
                     .what()                  \
              << std::endl;                   \
    std::abort();                             \
  } while (false)
#endif
#else
// Exceptions enabled: raise Ort::Exception carrying the message and error code.
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
82
83#ifdef ORT_API_MANUAL_INIT
// If the macro ORT_API_MANUAL_INIT is defined, no static initialization
// will be performed. Instead, users must call InitApi() before using the
// ORT C++ APIs.
87//
88// InitApi() sets the global API object using the default initialization
89// logic. Users call this to initialize the ORT C++ APIs at a time that
90// makes sense in their program.
91inline void InitApi() noexcept;
92
93// InitApi(const OrtApi*) is used by custom operator libraries that are not
94// linked to onnxruntime. It sets the global API object, which is required
95// by the ORT C++ APIs.
96//
97// Example mycustomop.cc:
98//
99// #define ORT_API_MANUAL_INIT
100// #include <onnxruntime_cxx_api.h>
101// #undef ORT_API_MANUAL_INIT
102//
103// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
104// Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
105// // ...
106// }
107//
108inline void InitApi(const OrtApi* api) noexcept;
109#endif
110
111namespace detail {
112// This is used internally by the C++ API. This class holds the global
113// variable that points to the OrtApi.
114struct Global {
115 static const OrtApi* Api(const OrtApi* newValue = nullptr) noexcept {
116 // This block-level static will be initialized once when this function is
117 // first executed, delaying the call to DefaultInit() until it is first needed.
118 //
119 // When ORT_API_MANUAL_INIT is not defined, DefaultInit() calls
120 // OrtGetApiBase()->GetApi(), which may result in a shared library being
121 // loaded.
122 //
123 // Using a block-level static instead of a class-level static helps
124 // avoid issues with static initialization order and dynamic libraries
125 // loading other dynamic libraries.
126 //
127 // This makes it safe to include the C++ API headers in a shared library
128 // that is delay loaded or delay loads its dependencies.
129 //
130 // This DOES NOT make it safe to _use_ arbitrary ORT C++ APIs when
131 // initializing static members, however.
132 static const OrtApi* api = DefaultInit();
133
134 if (newValue) {
135 api = newValue;
136 }
137
138 return api;
139 }
140
141 private:
142 // Has different definitions based on ORT_API_MANUAL_INIT
143 static const OrtApi* DefaultInit() noexcept;
144
145#ifdef ORT_API_MANUAL_INIT
146 // Public APIs to set the OrtApi* to use.
147 friend void ::Ort::InitApi() noexcept;
148 friend void ::Ort::InitApi(const OrtApi*) noexcept;
149#endif
150};
151} // namespace detail
152
153#ifdef ORT_API_MANUAL_INIT
154
155// See comments on declaration above for usage.
156inline void InitApi(const OrtApi* api) noexcept { detail::Global::Api(api); }
157inline void InitApi() noexcept { InitApi(OrtGetApiBase()->GetApi(ORT_API_VERSION)); }
158
159#ifdef _MSC_VER
160// If you get a linker error about a mismatch here, you are trying to
161// link two compilation units that have different definitions for
162// ORT_API_MANUAL_INIT together. All compilation units must agree on the
163// definition of ORT_API_MANUAL_INIT.
164#pragma detect_mismatch("ORT_API_MANUAL_INIT", "enabled")
165#endif
166
167inline const OrtApi* detail::Global::DefaultInit() noexcept {
168 // When ORT_API_MANUAL_INIT is defined, there's no default init that can
169 // be done.
170 return nullptr;
171}
172
173#else // ORT_API_MANUAL_INIT
174
175#ifdef _MSC_VER
176// If you get a linker error about a mismatch here, you are trying to link
177// two compilation units that have different definitions for
178// ORT_API_MANUAL_INIT together. All compilation units must agree on the
179// definition of ORT_API_MANUAL_INIT.
180#pragma detect_mismatch("ORT_API_MANUAL_INIT", "disabled")
181#endif
182
183inline const OrtApi* detail::Global::DefaultInit() noexcept {
185}
186#endif // ORT_API_MANUAL_INIT
187
189inline const OrtApi& GetApi() noexcept { return *detail::Global::Api(); }
190
// Declarations only; the definitions are not part of this section of the header.

/// Returns the version as a string (presumably the ONNX Runtime library version -- confirm at the definition).
std::string GetVersionString();

/// Returns a string describing the build (exact contents are determined by the definition).
std::string GetBuildInfoString();

/// Returns the names of the available execution providers, one entry per provider.
std::vector<std::string> GetAvailableProviders();
210
216 auto* api = GetApi().GetModelEditorApi();
217 if (api == nullptr) {
218 // minimal build
219 ORT_CXX_API_THROW("Model Editor API is not available in this build", ORT_FAIL);
220 }
221
222 return *api;
223}
224
230 auto* api = GetApi().GetCompileApi();
231 if (api == nullptr) {
232 // minimal build
233 ORT_CXX_API_THROW("Compile API is not available in this build", ORT_FAIL);
234 }
235
236 return *api;
237}
238
243inline const OrtEpApi& GetEpApi() {
244 auto* api = GetApi().GetEpApi();
245 if (api == nullptr) {
246 // minimal build
247 ORT_CXX_API_THROW("EP API is not available in this build", ORT_FAIL);
248 }
249
250 return *api;
251}
252
271struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
272 private:
278 constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
279
280 public:
281 using Base = onnxruntime_float16::Float16Impl<Float16_t>;
282
286 Float16_t() = default;
287
293 constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
294
299 explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
300
305 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
306
311 using Base::IsNegative;
312
317 using Base::IsNaN;
318
323 using Base::IsFinite;
324
329 using Base::IsPositiveInfinity;
330
335 using Base::IsNegativeInfinity;
336
341 using Base::IsInfinity;
342
347 using Base::IsNaNOrZero;
348
353 using Base::IsNormal;
354
359 using Base::IsSubnormal;
360
365 using Base::Abs;
366
371 using Base::Negate;
372
381 using Base::AreZero;
382
386 explicit operator float() const noexcept { return ToFloat(); }
387
388 using Base::operator==;
389 using Base::operator!=;
390 using Base::operator<;
391};
392
393static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
394
413struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
414 private:
422 constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
423
424 public:
425 using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
426
427 BFloat16_t() = default;
428
434 static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
435
440 explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
441
446 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
447
452 using Base::IsNegative;
453
458 using Base::IsNaN;
459
464 using Base::IsFinite;
465
470 using Base::IsPositiveInfinity;
471
476 using Base::IsNegativeInfinity;
477
482 using Base::IsInfinity;
483
488 using Base::IsNaNOrZero;
489
494 using Base::IsNormal;
495
500 using Base::IsSubnormal;
501
506 using Base::Abs;
507
512 using Base::Negate;
513
522 using Base::AreZero;
523
527 explicit operator float() const noexcept { return ToFloat(); }
528
529 // We do not have an inherited impl for the below operators
530 // as the internal class implements them a little differently
531 bool operator==(const BFloat16_t& rhs) const noexcept;
532 bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
533 bool operator<(const BFloat16_t& rhs) const noexcept;
534};
535
536static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
537
/// Float8E4M3FN storage type: a thin wrapper around one uint8_t holding the raw bits.
/// Converts implicitly from and to uint8_t; no arithmetic is provided here.
struct Float8E4M3FN_t {
  uint8_t value;
  constexpr Float8E4M3FN_t() noexcept : value(0) {}
  constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }
};
552
553static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
554
/// Float8E4M3FNUZ storage type: a thin wrapper around one uint8_t holding the raw bits.
/// Converts implicitly from and to uint8_t; no arithmetic is provided here.
struct Float8E4M3FNUZ_t {
  uint8_t value;
  constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}
  constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }
};
569
570static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
571
/// Float8E5M2 storage type: a thin wrapper around one uint8_t holding the raw bits.
/// Converts implicitly from and to uint8_t; no arithmetic is provided here.
struct Float8E5M2_t {
  uint8_t value;
  constexpr Float8E5M2_t() noexcept : value(0) {}
  constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {}
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }
};
586
587static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
588
/// Float8E5M2FNUZ storage type: a thin wrapper around one uint8_t holding the raw bits.
/// Converts implicitly from and to uint8_t; no arithmetic is provided here.
struct Float8E5M2FNUZ_t {
  uint8_t value;
  constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}
  constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }
};
603
604static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
605
606namespace detail {
607// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
608// This can't be done in the C API since C doesn't have function overloading.
609#define ORT_DEFINE_RELEASE(NAME) \
610 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
611
612#define ORT_DEFINE_RELEASE_FROM_API_STRUCT(NAME, API_GETTER) \
613 inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); }
614
615ORT_DEFINE_RELEASE(Allocator);
616ORT_DEFINE_RELEASE(ArenaCfg);
617ORT_DEFINE_RELEASE(CustomOpDomain);
618ORT_DEFINE_RELEASE(Env);
619ORT_DEFINE_RELEASE(ExternalInitializerInfo);
620ORT_DEFINE_RELEASE(Graph);
621ORT_DEFINE_RELEASE(IoBinding);
622ORT_DEFINE_RELEASE(KernelInfo);
623ORT_DEFINE_RELEASE(KeyValuePairs);
624ORT_DEFINE_RELEASE(LoraAdapter);
625ORT_DEFINE_RELEASE(MemoryInfo);
626ORT_DEFINE_RELEASE(MapTypeInfo);
627ORT_DEFINE_RELEASE(Model);
628ORT_DEFINE_RELEASE(ModelMetadata);
629ORT_DEFINE_RELEASE(Node);
630ORT_DEFINE_RELEASE(Op);
631ORT_DEFINE_RELEASE(OpAttr);
632ORT_DEFINE_RELEASE(PrepackedWeightsContainer);
633ORT_DEFINE_RELEASE(RunOptions);
634ORT_DEFINE_RELEASE(Session);
635ORT_DEFINE_RELEASE(SessionOptions);
636ORT_DEFINE_RELEASE(SequenceTypeInfo);
637ORT_DEFINE_RELEASE(Status);
638ORT_DEFINE_RELEASE(SyncStream);
639ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
640ORT_DEFINE_RELEASE(ThreadingOptions);
641ORT_DEFINE_RELEASE(TypeInfo);
642ORT_DEFINE_RELEASE(Value);
643ORT_DEFINE_RELEASE(ValueInfo);
644
645ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
646ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
647ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi);
648ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);
649ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi);
650
651// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
652// but the struct has V2 in its name to indicate that it is the second version of the options.
655
656#undef ORT_DEFINE_RELEASE
657#undef ORT_DEFINE_RELEASE_FROM_API_STRUCT
658
/// Tag type: Base<Unowned<T>> selects the non-owning specialization of Base,
/// whose contained type is Unowned<T>::Type (i.e. T).
template <typename T>
struct Unowned {
  using Type = T;
};

/// Owning holder for a pointer to a C API object of type T.
///
/// The destructor and move assignment hand the held pointer to the
/// OrtRelease() overload for T. Copying is disabled; ownership moves.
template <typename T>
struct Base {
  using contained_type = T;

  constexpr Base() = default;
  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
  ~Base() {
    OrtRelease(p_);
  }

  Base(const Base&) = delete;
  Base& operator=(const Base&) = delete;

  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
  Base& operator=(Base&& v) noexcept {
    // Guard against self-move: without it the held object would be released
    // and its (now dangling) pointer stored back, to be released a second time.
    if (this != &v) {
      OrtRelease(p_);
      p_ = v.release();
    }
    return *this;
  }

  constexpr operator contained_type*() const noexcept { return p_; }
  constexpr contained_type& operator*() const noexcept { return *p_; }

  /// Gives up ownership: returns the held pointer and leaves this object empty.
  /// The pointed-to object is not released.
  contained_type* release() {
    T* p = p_;
    p_ = nullptr;
    return p;
  }

 protected:
  contained_type* p_{};
};

// Undefined. For const types use Base<Unowned<const T>>
template <typename T>
struct Base<const T>;

/// Non-owning holder for a pointer to T (T may be const-qualified).
///
/// Nothing is released on destruction, so instances are freely copyable.
/// Moving transfers the pointer and leaves the source empty.
template <typename T>
struct Base<Unowned<T>> {
  using contained_type = typename Unowned<T>::Type;

  constexpr Base() = default;
  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}

  ~Base() = default;

  Base(const Base&) = default;
  Base& operator=(const Base&) = default;

  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
  Base& operator=(Base&& v) noexcept {
    // Self-move keeps the pointer instead of silently dropping it.
    if (this != &v) {
      p_ = v.p_;
      v.p_ = nullptr;
    }
    return *this;
  }

  constexpr operator contained_type*() const noexcept { return p_; }
  constexpr contained_type& operator*() const noexcept { return *p_; }

 protected:
  contained_type* p_{};
};
757
758// Light functor to release memory with OrtAllocator
761 explicit AllocatedFree(OrtAllocator* allocator)
762 : allocator_(allocator) {}
763 void operator()(void* ptr) const {
764 if (ptr) allocator_->Free(allocator_, ptr);
765 }
766};
767
768} // namespace detail
769
770struct AllocatorWithDefaultOptions;
771struct Env;
772struct EpDevice;
773struct ExternalInitializerInfo;
774struct Graph;
775struct Model;
776struct Node;
777struct ModelMetadata;
778struct TypeInfo;
779struct PrepackedWeightsContainer;
780struct Session;
781struct SessionOptions;
782struct SyncStream;
783struct TensorRTProviderOptions;
784struct Value;
785struct ValueInfo;
786
791using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
792
797struct Status : detail::Base<OrtStatus> {
798 Status() = default; // Same as with std::nullptr_t. But can be used in re-sizable containers and represent success.
799 explicit Status(std::nullptr_t) noexcept {}
800 explicit Status(OrtStatus* status) noexcept;
801 explicit Status(const Exception&);
802 explicit Status(const std::exception&);
803 Status(const char* message, OrtErrorCode code);
804 std::string GetErrorMessage() const;
806 bool IsOK() const noexcept;
807};
808
838
843struct TensorRTProviderOptions : detail::Base<OrtTensorRTProviderOptionsV2> {
844 TensorRTProviderOptions(std::nullptr_t) {}
848 void Update(const std::unordered_map<std::string, std::string>& options);
850 void UpdateWithValue(const char* key, void* value);
851
853 void* GetOptionByName(const char* name) const;
856};
857
862struct CUDAProviderOptions : detail::Base<OrtCUDAProviderOptionsV2> {
863 CUDAProviderOptions(std::nullptr_t) {}
867 void Update(const std::unordered_map<std::string, std::string>& options);
871 void UpdateWithValue(const char* key, void* value);
873 void* GetOptionByName(const char* name) const;
874};
875
890
891namespace detail {
892template <typename T>
894 using B = Base<T>;
895 using B::B;
896
897 // Wraps OrtApi::ExternalInitializerInfo_GetFilePath
898 const std::basic_string<ORTCHAR_T> GetFilePath() const;
899 // Wraps OrtApi::ExternalInitializerInfo_GetFileOffset
900 int64_t GetFileOffset() const;
901 // Wraps OrtApi::ExternalInitializerInfo_GetByteSize
902 size_t GetByteSize() const;
903};
904} // namespace detail
905
906// Const object holder that does not own the underlying object
909
915 using Base::Base;
916
917 explicit ExternalInitializerInfo(std::nullptr_t) {}
919 : detail::ConstExternalInitializerInfoImpl<OrtExternalInitializerInfo>{p} {}
920
922
924 ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size);
925
927 static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size,
928 /*out*/ ExternalInitializerInfo& out);
929};
930
931namespace detail {
932template <typename T>
935 using B::B;
936
937 const char* GetValue(const char* key) const;
938
939 // get the pairs in unordered_map. needs to copy to std::string so the hash works as expected
940 std::unordered_map<std::string, std::string> GetKeyValuePairs() const;
941 // get the pairs in two vectors. entries will be 1:1 between keys and values. avoids copying to std::string
942 void GetKeyValuePairs(std::vector<const char*>& keys, std::vector<const char*>& values) const;
943};
944} // namespace detail
945
946// Const object holder that does not own the underlying object
948
950struct KeyValuePairs : detail::KeyValuePairsImpl<OrtKeyValuePairs> {
951 explicit KeyValuePairs(std::nullptr_t) {}
953 explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl<OrtKeyValuePairs>{p} {}
954
956 explicit KeyValuePairs();
957
959 explicit KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs);
960
962 void Add(const char* key, const char* value);
963
965 void Remove(const char* key);
966
967 ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; }
968};
969
970namespace detail {
971template <typename T>
972struct MemoryInfoImpl : Base<T> {
973 using B = Base<T>;
974 using B::B;
975
976 std::string GetAllocatorName() const;
978 int GetDeviceId() const;
982 uint32_t GetVendorId() const;
983
984 template <typename U>
985 bool operator==(const MemoryInfoImpl<U>& o) const;
986};
987} // namespace detail
988
989// Const object holder that does not own the underlying object
991
995struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
997 explicit MemoryInfo(std::nullptr_t) {}
998 explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {}
999 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1000 MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id,
1001 OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type);
1002 ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
1003};
1004
1012 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1017 MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1018
1019 void* get() { return p_; }
1020 size_t size() const { return size_; }
1021
1022 private:
1023 OrtAllocator* allocator_;
1024 void* p_;
1025 size_t size_;
1026};
1027
1028namespace detail {
1029template <typename T>
1030struct AllocatorImpl : Base<T> {
1031 using B = Base<T>;
1032 using B::B;
1033
1034 void* Alloc(size_t size);
1035 MemoryAllocation GetAllocation(size_t size);
1036 void Free(void* p);
1037 ConstMemoryInfo GetInfo() const;
1038
1043 KeyValuePairs GetStats() const;
1044};
1045} // namespace detail
1046
1050struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1051 explicit AllocatorWithDefaultOptions(std::nullptr_t) {}
1053};
1054
1059struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1060 explicit Allocator(std::nullptr_t) {}
1061 Allocator(const Session& session, const OrtMemoryInfo*);
1062
1064 explicit Allocator(OrtAllocator* p) : AllocatorImpl<OrtAllocator>{p} {}
1065};
1066
1067using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1068
1073namespace detail {
1074template <typename T>
1076 using B = Base<T>;
1077 using B::B;
1078 // For some reason this is not a const method on the stream
1079 void* GetHandle();
1080};
1081} // namespace detail
1082
1083struct SyncStream : detail::SyncStreamImpl<OrtSyncStream> {
1085 explicit SyncStream(std::nullptr_t) {}
1087 explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl<OrtSyncStream>{p} {}
1088};
1089
1091
1092namespace detail {
1093template <typename T>
1096 using B::B;
1097
1099 uint32_t VendorId() const;
1100 uint32_t DeviceId() const;
1101 const char* Vendor() const;
1103};
1104} // namespace detail
1105
1110
1111namespace detail {
1112template <typename T>
1115 using B::B;
1116
1117 const char* EpName() const;
1118 const char* EpVendor() const;
1124};
1125} // namespace detail
1126
1131
1134struct EpDevice : detail::EpDeviceImpl<OrtEpDevice> {
1135 explicit EpDevice(std::nullptr_t) {}
1136 explicit EpDevice(OrtEpDevice* p) : EpDeviceImpl<OrtEpDevice>{p} {}
1137
1139 EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
1140 ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {});
1141};
1142
1150 const std::vector<ConstEpDevice>& ep_devices,
1151 const char* compatibility_info);
1152
1158struct Env : detail::Base<OrtEnv> {
1159 explicit Env(std::nullptr_t) {}
1160
1162 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
1163
1165 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
1166
1168 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
1169
1171 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
1172 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
1173
1175 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
1176
1179
1181
1182 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
1183
1184 Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info,
1185 const std::unordered_map<std::string, std::string>& options,
1186 const OrtArenaCfg* arena_cfg);
1187
1189
1191
1193 OrtAllocatorType allocator_type,
1194 const OrtKeyValuePairs* allocator_options);
1195
1196 // Result may be nullptr
1198
1200 OrtDeviceMemoryType mem_type);
1201
1202 Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string<ORTCHAR_T>& path);
1203 Env& UnregisterExecutionProviderLibrary(const char* registration_name);
1204
1205 std::vector<ConstEpDevice> GetEpDevices() const;
1206
1207 Status CopyTensors(const std::vector<Value>& src_tensors,
1208 const std::vector<Value>& dst_tensors,
1209 OrtSyncStream* stream) const;
1210};
1211
1215struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
1217 using Base::Base;
1218
1219 explicit CustomOpDomain(std::nullptr_t) {}
1220
1222 explicit CustomOpDomain(const char* domain);
1223
1224 // This does not take ownership of the op, simply registers it.
1225 void Add(const OrtCustomOp* op);
1226};
1227
1229struct LoraAdapter : detail::Base<OrtLoraAdapter> {
1231 using Base::Base;
1232
1233 explicit LoraAdapter(std::nullptr_t) {}
1240 static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
1241 OrtAllocator* allocator);
1242
1250 static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
1251 OrtAllocator* allocator);
1252};
1253
1257struct RunOptions : detail::Base<OrtRunOptions> {
1258 explicit RunOptions(std::nullptr_t) {}
1260
1263
1266
1267 RunOptions& SetRunTag(const char* run_tag);
1268 const char* GetRunTag() const;
1269
1270 RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
1271 const char* GetConfigEntry(const char* config_key);
1272
1279
1285
1293};
1294
namespace detail {
// Utility function that returns a SessionOption config entry key for a specific custom operator.
// Ex: custom_op.[custom_op_name].[config]
// Declaration only; the definition is not part of this section of the header.
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
}  // namespace detail
1300
1311 CustomOpConfigs() = default;
1312 ~CustomOpConfigs() = default;
1317
1326 CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
1327
1336 const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
1337
1338 private:
1339 std::unordered_map<std::string, std::string> flat_configs_;
1340};
1341
1347namespace detail {
1348// we separate const-only methods because passing const ptr to non-const methods
1349// is only discovered when inline methods are compiled which is counter-intuitive
1350template <typename T>
1351struct ConstSessionOptionsImpl : Base<T> {
1352 using B = Base<T>;
1353 using B::B;
1354
1355 SessionOptions Clone() const;
1356
1357 std::string GetConfigEntry(const char* config_key) const;
1358 bool HasConfigEntry(const char* config_key) const;
1359 std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const;
1360};
1361
1362template <typename T>
1363struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
1364 using B = ConstSessionOptionsImpl<T>;
1365 using B::B;
1366
1367 SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads);
1368 SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads);
1369 SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
1370 SessionOptionsImpl& SetDeterministicCompute(bool value);
1371
1372 SessionOptionsImpl& EnableCpuMemArena();
1373 SessionOptionsImpl& DisableCpuMemArena();
1374
1375 SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
1376
1377 SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
1378 SessionOptionsImpl& DisableProfiling();
1379
1380 SessionOptionsImpl& EnableOrtCustomOps();
1381
1382 SessionOptionsImpl& EnableMemPattern();
1383 SessionOptionsImpl& DisableMemPattern();
1384
1385 SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode);
1386
1387 SessionOptionsImpl& SetLoadCancellationFlag(bool value);
1388
1389 SessionOptionsImpl& SetLogId(const char* logid);
1390 SessionOptionsImpl& SetLogSeverityLevel(int level);
1391
1392 SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain);
1393
1394 SessionOptionsImpl& DisablePerSessionThreads();
1395
1396 SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value);
1397
1398 SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val);
1399 SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values);
1400 SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
1401 const std::vector<char*>& external_initializer_file_buffer_array,
1402 const std::vector<size_t>& external_initializer_file_lengths);
1403
1404 SessionOptionsImpl& AppendExecutionProvider_CPU(int use_arena);
1405 SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
1406 SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options);
1407 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
1408 SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
1410 SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
1411 SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
1412 SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options);
1413 SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options);
1415 SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
1417 SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
1419 SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
1420 const std::unordered_map<std::string, std::string>& provider_options = {});
1421
1424 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
1425 const KeyValuePairs& ep_options);
1428 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
1429 const std::unordered_map<std::string, std::string>& ep_options);
1430
1432 SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy);
1433
1435 SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr);
1436
1437 SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
1438 SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
1439 SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
1440
1444 SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
1445
1446 SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name);
1447
1449 SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
1450
1452 SessionOptionsImpl& AddFreeDimensionOverride(const char* dim_denotation, int64_t dim_value);
1453
1455 SessionOptionsImpl& AddFreeDimensionOverrideByName(const char* dim_name, int64_t dim_value);
1456};
1457} // namespace detail
1458
1459using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
1460using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
1461
// Wrapper for OrtSessionOptions. All option setters come from the inherited
// detail::SessionOptionsImpl<OrtSessionOptions> interface; this struct only adds constructors.
1465struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
// Creates an empty object that holds no underlying OrtSessionOptions; assign a real instance before use.
1466 explicit SessionOptions(std::nullptr_t) {}
// Wraps an existing OrtSessionOptions pointer by forwarding it to the base class.
// NOTE(review): presumably the wrapper takes ownership of `p` - confirm against detail::Base.
1468 explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {}
1471};
1472
// Wrapper for OrtModelCompilationOptions: the settings object that is later handed to CompileModel().
// Every Set* member returns ModelCompilationOptions& (presumably *this, so calls can be chained - confirm
// against the definitions).
// NOTE(review): several declarations below are missing their first line in this listing; only the
// trailing parameter lines are visible for those members.
1477struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
1479 using Base::Base;
1480
// Creates an empty object that holds no underlying OrtModelCompilationOptions.
1481 explicit ModelCompilationOptions(std::nullptr_t) {}
1482
// Construct from an environment plus session options, given either as an owning SessionOptions
// or as a non-owning ConstSessionOptions view.
1483 ModelCompilationOptions(const Env& env, const SessionOptions& session_options);
1484 ModelCompilationOptions(const Env& env, ConstSessionOptions session_options);
1485
// Input model given as a file path.
1486 ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path);
1488 size_t input_model_data_size);
1489 ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model);
// Output model written to a file path.
1490 ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path);
1492 size_t initializer_size_threshold);
1493
1496 OrtGetInitializerLocationFunc get_initializer_location_func,
1497 void* state);
1498
// Output model returned through caller-provided out-pointers (buffer and its size), using `allocator`.
// NOTE(review): presumably the buffer is allocated with `allocator` and must be released by the caller - confirm.
1499 ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
1500 size_t* output_model_buffer_size_ptr);
1501
1504
1505 ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
1506 const ORTCHAR_T* model_name);
1508
1510};
1511
// Compiles a model described by `model_compilation_options` within `env` and reports the outcome as a Status
// return value. NOTE(review): described from the signature only; the definition is not in this part of the file.
1518Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options);
1519
// Wrapper for OrtModelMetadata.
// NOTE(review): most accessor declarations of this struct are missing from this listing (the runs of
// otherwise empty numbered lines below); only two members remain visible.
1523struct ModelMetadata : detail::Base<OrtModelMetadata> {
1525 using Base::Base;
1526
// Creates an empty object that holds no underlying OrtModelMetadata.
1527 explicit ModelMetadata(std::nullptr_t) {}
1528
1536
1544
1552
1560
1568
// Returns the custom metadata map keys as a vector of AllocatedStringPtr obtained through `allocator`.
1575 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const;
1576
1587
// Returns the model version as a 64-bit integer.
1588 int64_t GetVersion() const;
1589};
1590
1591struct IoBinding;
1592
1593namespace detail {
1594
1595// we separate const-only methods because passing const ptr to non-const methods
1596// is only discovered when inline methods are compiled which is counter-intuitive
1597template <typename T>
1599 using B = Base<T>;
1600 using B::B;
1601
1602 size_t GetInputCount() const;
1603 size_t GetOutputCount() const;
1605
1606 std::vector<std::string> GetInputNames() const;
1607 std::vector<std::string> GetOutputNames() const;
1608 std::vector<std::string> GetOverridableInitializerNames() const;
1609
1610 std::vector<ConstMemoryInfo> GetMemoryInfoForInputs() const;
1611 std::vector<ConstMemoryInfo> GetMemoryInfoForOutputs() const;
1612 std::vector<ConstEpDevice> GetEpDeviceForInputs() const;
1613
1622
1631
1640
1641 uint64_t GetProfilingStartTimeNs() const;
1643
1644 TypeInfo GetInputTypeInfo(size_t index) const;
1645 TypeInfo GetOutputTypeInfo(size_t index) const;
1647
1648 int GetOpset(const std::string& domain) const;
1649
1650 // Will move before checkin if that's the case.
1651 std::vector<ValueInfo> GetInputs() const;
1652 std::vector<ValueInfo> GetOutputs() const;
1653};
1654
1655template <typename T>
1658 using B::B;
1659
1677 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1678 const char* const* output_names, size_t output_count);
1679
1683 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1684 const char* const* output_names, Value* output_values, size_t output_count);
1685
1686 void Run(const RunOptions& run_options, const IoBinding&);
1687
1707 void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1708 const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1709
1717
1729 void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len);
1730
1731 void FinalizeModelEditorSession(const Model& model, const SessionOptions& options,
1732 OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr);
1733};
1734
1735} // namespace detail
1736
1739
// Wrapper for OrtSession. The run/query interface is inherited from detail::SessionImpl<OrtSession>;
// this struct adds the constructors/factories and the conversions to non-owning views.
1743struct Session : detail::SessionImpl<OrtSession> {
// Creates an empty object that holds no underlying OrtSession.
1745 explicit Session(std::nullptr_t) {}
// Wraps an existing OrtSession pointer by forwarding it to the base class.
// NOTE(review): presumably takes ownership of `p` - confirm against detail::Base.
1746 explicit Session(OrtSession* p) : SessionImpl{p} {}
1747
// Create a session from a model file path.
1748 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
1749
// Same as above, additionally passing a prepacked-weights container.
1751 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1752 OrtPrepackedWeightsContainer* prepacked_weights_container);
1753
// Create a session from an in-memory model of `model_data_length` bytes.
1755 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
1756
// Same as above, additionally passing a prepacked-weights container.
1758 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1759 OrtPrepackedWeightsContainer* prepacked_weights_container);
1760
// The members below are compiled out when ORT_MINIMAL_BUILD is defined.
1761#if !defined(ORT_MINIMAL_BUILD)
// Create a session from a Model object.
1763 Session(const Env& env, const Model& model, const SessionOptions& options);
1764
// Factories for a "model editor" session, from a file path or from an in-memory model.
// NOTE(review): presumably paired with FinalizeModelEditorSession() on the session interface - confirm.
1766 static Session CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
1767
1769 static Session CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length,
1770 const SessionOptions& options);
1771#endif // !defined(ORT_MINIMAL_BUILD)
1772
// Views over the same underlying pointer. Per the file header, Const/Unowned types do not own the
// C object, so this Session must outlive any view returned here.
1773 ConstSession GetConst() const { return ConstSession{this->p_}; }
1774 UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
1775};
1776
1777namespace detail {
1778template <typename T>
1780 using B = Base<T>;
1781 using B::B;
1782
1784 size_t GetElementCount() const;
1785
1786 size_t GetDimensionsCount() const;
1787
1792 [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const;
1793
1794 void GetSymbolicDimensions(const char** values, size_t values_count) const;
1795 std::vector<const char*> GetSymbolicDimensions() const;
1796
1797 bool HasShape() const;
1798 std::vector<int64_t> GetShape() const;
1799};
1800
1801} // namespace detail
1802
1804
1810 using Base::Base;
1811
1813 explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
1815 explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {}
1816
1817 // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions
1818 // symbolic_dims are optional, but should be 1:1 with dims.
1819 // The value in symbolic_dims will be used for all entries in dims that are -1.
1821 const std::vector<int64_t>& dims,
1822 const std::vector<std::string>* symbolic_dims = nullptr);
1823
1825};
1826
1827namespace detail {
1828template <typename T>
1830 using B = Base<T>;
1831 using B::B;
1833};
1834
1835} // namespace detail
1836
1838
// Wrapper for OrtSequenceTypeInfo; the query interface comes from detail::SequenceTypeInfoImpl.
1842struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
1844 using Base::Base;
1845
// Creates an empty object that holds no underlying OrtSequenceTypeInfo.
1846 explicit SequenceTypeInfo(std::nullptr_t) {}
// Wraps an existing OrtSequenceTypeInfo pointer by forwarding it to the base class.
1847 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {}
1849};
1850
1851namespace detail {
1852template <typename T>
1854 using B = Base<T>;
1855 using B::B;
1857};
1858
1859} // namespace detail
1860
1861// This is always owned by the TypeInfo and can only be obtained from it.
1863
1864namespace detail {
1865template <typename T>
1872
1873} // namespace detail
1874
1876
// Wrapper for OrtMapTypeInfo; the query interface comes from detail::MapTypeInfoImpl.
1880struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
1882 using Base::Base;
1883
// Creates an empty object that holds no underlying OrtMapTypeInfo.
1884 explicit MapTypeInfo(std::nullptr_t) {}
// Wraps an existing OrtMapTypeInfo pointer by forwarding it to the base class.
1885 explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {}
// Non-owning const view over the same pointer; this object must outlive the returned view.
1886 ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
1887};
1888
1889namespace detail {
1890template <typename T>
1902} // namespace detail
1903
1909
// Wrapper for OrtTypeInfo; the query interface comes from detail::TypeInfoImpl.
1914struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
1916 using Base::Base;
1917
// Creates an empty object that holds no underlying OrtTypeInfo.
1919 explicit TypeInfo(std::nullptr_t) {}
// Wraps an existing OrtTypeInfo pointer by forwarding it to the base class.
1920 explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {}
1921
// NOTE(review): the declarations guarded by this conditional are missing from this listing.
1922#if !defined(ORT_MINIMAL_BUILD)
1928#endif // !defined(ORT_MINIMAL_BUILD)
1929
// Non-owning const view over the same pointer; this object must outlive the returned view.
1930 ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
1931};
1932
1933namespace detail {
1934// This structure is used to feed sparse tensor values
1935// information for use with FillSparseTensor<Format>() API
1936// if the data type for the sparse tensor values is numeric
1937// use data.p_data, otherwise, use data.str pointer to feed
1938// values. data.str is an array of const char* that are zero terminated.
1939// number of strings in the array must match shape size.
1940// For fully sparse tensors use shape {0} and set p_data/str
1941// to nullptr.
1943 const int64_t* values_shape;
1945 union {
1946 const void* p_data;
1947 const char** str;
1948 } data;
1949};
1950
1951// Provides a way to pass shape in a single
1952// argument
1953struct Shape {
1954 const int64_t* shape;
1956};
1957
1958template <typename T>
1960 using B = Base<T>;
1961 using B::B;
1962
1966 template <typename R>
1967 void GetOpaqueData(const char* domain, const char* type_name, R&) const;
1968
1969 bool IsTensor() const;
1970 bool HasValue() const;
1971
1972 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
1973 Value GetValue(int index, OrtAllocator* allocator) const;
1974
1982
1997 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1998
2005 template <typename R>
2006 const R* GetTensorData() const;
2007
2012 const void* GetTensorRawData() const;
2013
2021
2029
2035
2044 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
2045
2052 std::string GetStringTensorElement(size_t element_index) const;
2053
2060 size_t GetStringTensorElementLength(size_t element_index) const;
2061
2068 size_t GetTensorSizeInBytes() const;
2069
2070#if !defined(DISABLE_SPARSE_TENSORS)
2078
2085
2094
2104 template <typename R>
2105 const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
2106
2111 bool IsSparseTensor() const;
2112
2121 template <typename R>
2122 const R* GetSparseTensorValues() const;
2123
2124#endif
2125};
2126
2127template <typename T>
2130 using B::B;
2131
2137 template <typename R>
2139
2145
2147 // Obtain a reference to an element of data at the location specified
2153 template <typename R>
2154 R& At(const std::vector<int64_t>& location);
2155
2161 void FillStringTensor(const char* const* s, size_t s_len);
2162
2168 void FillStringTensorElement(const char* s, size_t index);
2169
2182 char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
2183
2184#if !defined(DISABLE_SPARSE_TENSORS)
2193 void UseCooIndices(int64_t* indices_data, size_t indices_num);
2194
2205 void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
2206
2215 void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
2216
2226 void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
2227 const int64_t* indices_data, size_t indices_num);
2228
2240 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
2241 const OrtSparseValuesParam& values,
2242 const int64_t* inner_indices_data, size_t inner_indices_num,
2243 const int64_t* outer_indices_data, size_t outer_indices_num);
2244
2255 const OrtSparseValuesParam& values,
2256 const Shape& indices_shape,
2257 const int32_t* indices_data);
2258
2259#endif
2260};
2261
2262} // namespace detail
2263
2266
// Wrapper for OrtValue. Accessors are inherited from detail::ValueImpl<OrtValue>; this struct adds
// construction, move operations, view conversions and the static Create* factories.
// NOTE(review): three CreateTensor overloads below are missing their final parameter line in this listing.
2270struct Value : detail::ValueImpl<OrtValue> {
2272 using Base::Base;
2275
// Creates an empty Value. Unlike most wrappers in this file this constructor is not `explicit`.
2276 Value(std::nullptr_t) {}
// Move-only: move construction and move assignment are defaulted; no copy operations are declared here.
2277 Value(Value&&) = default;
2278 Value& operator=(Value&&) = default;
2279
// Views over the same underlying pointer; per the file header they do not own the C object,
// so this Value must outlive them.
2280 ConstValue GetConst() const { return ConstValue{this->p_}; }
2281 UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
2282
// Tensor over a caller-supplied typed buffer of `p_data_element_count` elements with the given shape.
// NOTE(review): presumably the buffer is not copied and must outlive the Value - confirm.
2291 template <typename T>
2292 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count,
2293 const int64_t* shape, size_t shape_len);
2294
// Tensor over a caller-supplied untyped buffer; size is given in bytes.
2304 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count,
2305 const int64_t* shape, size_t shape_len,
2307
// Tensor over an untyped buffer, with an allocator passed as `deleter`.
// NOTE(review): presumably `deleter` is used to free `p_data` when the Value is destroyed - confirm.
2317 static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count,
2318 const int64_t* shape, size_t shape_len,
2320
// Tensor of element type T with the given shape, with storage obtained from `allocator`.
2332 template <typename T>
2333 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
2334
// Non-template variant of the allocator-backed tensor factory.
2346 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len,
2348
// Builds a map value from a keys Value and a values Value.
2357 static Value CreateMap(const Value& keys, const Value& values);
2358
// Builds a sequence value from a vector of Values.
2366 static Value CreateSequence(const std::vector<Value>& values);
2367
// Builds an opaque value identified by `domain` and `type_name` from `value`.
2376 template <typename T>
2377 static Value CreateOpaque(const char* domain, const char* type_name, const T& value);
2378
// Sparse tensor factories; compiled out when DISABLE_SPARSE_TENSORS is defined.
2379#if !defined(DISABLE_SPARSE_TENSORS)
// Sparse tensor over caller-supplied typed values data, with separate dense and values shapes.
2390 template <typename T>
2391 static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
2392 const Shape& values_shape);
2393
// Untyped variant: the element type is passed explicitly.
2410 static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
2411 const Shape& values_shape, ONNXTensorElementDataType type);
2412
// Allocator-backed sparse tensor of element type T with the given dense shape.
2422 template <typename T>
2423 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
2424
// Allocator-backed sparse tensor with the element type passed explicitly.
2436 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
2437
2438#endif // !defined(DISABLE_SPARSE_TENSORS)
2439};
2440
2441namespace detail {
2442namespace binding_utils {
2443// Bring these out of template
2444std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
2445std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
2446} // namespace binding_utils
2447
2448template <typename T>
2450 using B = Base<T>;
2451 using B::B;
2452
2453 std::vector<std::string> GetOutputNames() const;
2454 std::vector<std::string> GetOutputNames(OrtAllocator*) const;
2455 std::vector<Value> GetOutputValues() const;
2456 std::vector<Value> GetOutputValues(OrtAllocator*) const;
2457};
2458
2459template <typename T>
2462 using B::B;
2463
2464 void BindInput(const char* name, const Value&);
2465 void BindOutput(const char* name, const Value&);
2466 void BindOutput(const char* name, const OrtMemoryInfo*);
2471};
2472
2473} // namespace detail
2474
2477
// Wrapper for OrtIoBinding; the bind/query interface comes from detail::IoBindingImpl<OrtIoBinding>.
2481struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
// Creates an empty object that holds no underlying OrtIoBinding.
2482 explicit IoBinding(std::nullptr_t) {}
// Creates a binding for the given session (definition not in this part of the file).
2483 explicit IoBinding(Session& session);
// Views over the same underlying pointer; they do not own it, so this IoBinding must outlive them.
2484 ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
2485 UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
2486};
2487
// Wrapper for OrtArenaCfg (arena allocator configuration).
2492struct ArenaCfg : detail::Base<OrtArenaCfg> {
// Creates an empty object that holds no underlying OrtArenaCfg.
2493 explicit ArenaCfg(std::nullptr_t) {}
// Builds a configuration from four positional settings.
// NOTE(review): units and sentinel values of these parameters are not visible here - see the C API docs.
2502 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
2503
// Builds a configuration from a map of setting name to size_t value.
2508 explicit ArenaCfg(const std::unordered_map<std::string, size_t>& arena_config);
2509};
2510
2511//
2512// Custom OPs (only needed to implement custom OPs)
2513//
2514
2515namespace detail {
2516// Need to define a templated ConstOpAttr with const members
2517template <typename T>
2520 using B::B;
2521
2522 // Wraps OrtApi::OpAttr_GetName
2523 std::string GetName() const;
2524 // Wraps OrtApi::OpAttr_GetType
2526
2527 // Wraps OrtApi::ReadAttr for a single value
2528 // This does not support Tensor Attribute
2529 // Call GetTensorAttributeAsOrtValue() instead.
2530 template <typename R>
2531 Status GetValue(R& out) const;
2532
2533 // Wraps OrtApi::ReadAttr for an array of values
2534 template <typename R>
2535 Status GetValueArray(std::vector<R>& out) const;
2536 // Wraps OrtApi::OpAttr_GetTensorAttributeAsOrtValue
2538};
2539} // namespace detail
2540
2542
// Wrapper for OrtOpAttr; read accessors come from detail::ConstOpAttrImpl<OrtOpAttr>.
2546struct OpAttr : detail::ConstOpAttrImpl<OrtOpAttr> {
2548 using Base::Base;
2549
2550 OpAttr() = default; // Enable storing it in the container for resize()
// Creates an empty object that holds no underlying OrtOpAttr.
2551 explicit OpAttr(std::nullptr_t) {}
// Creates an attribute called `name` of kind `type` from `len` units of `data`.
// NOTE(review): whether `len` counts elements or bytes is not visible here - confirm against OrtApi::CreateOpAttr.
2552 OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
2553
// Non-owning const view over the same pointer; this object must outlive the returned view.
2554 ConstOpAttr GetConst() const { return ConstOpAttr{this->p_}; }
2555};
2556
// Logs `message` through `logger` when `message_severity` is at or above the logger's severity level,
// and passes the resulting status to Ort::ThrowOnError.
//
// Every macro argument is parenthesized so that arbitrary expressions can be passed safely, e.g. a
// dereferenced pointer (`*logger_ptr`) as the logger or a conditional expression as the severity.
// Note: `logger` and `message_severity` are each evaluated twice when the message is emitted, so avoid
// arguments with side effects.
#define ORT_CXX_LOG(logger, message_severity, message)                                          \
  do {                                                                                          \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                             \
      Ort::ThrowOnError((logger).LogMessage((message_severity), ORT_FILE, __LINE__,             \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    }                                                                                           \
  } while (false)
2572
// Logs `message` through `logger` when `message_severity` is at or above the logger's severity level.
// The status returned by LogMessage is deliberately discarded, so this form never throws on a failed write.
//
// Every macro argument is parenthesized so that arbitrary expressions can be passed safely, e.g. a
// dereferenced pointer (`*logger_ptr`) as the logger or a conditional expression as the severity.
// Note: `logger` and `message_severity` are each evaluated twice when the message is emitted, so avoid
// arguments with side effects.
#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message)                                 \
  do {                                                                                          \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                             \
      static_cast<void>((logger).LogMessage((message_severity), ORT_FILE, __LINE__,             \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    }                                                                                           \
  } while (false)
2588
// Logs a printf-style formatted message through `logger` when `message_severity` is at or above the
// logger's severity level, and passes the resulting status to Ort::ThrowOnError. The variadic arguments
// are the format string followed by its arguments.
//
// `logger` and `message_severity` are parenthesized so that arbitrary expressions can be passed safely,
// e.g. a dereferenced pointer (`*logger_ptr`) as the logger or a conditional expression as the severity.
// Note: both are evaluated twice when the message is emitted, so avoid arguments with side effects.
#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...)                                      \
  do {                                                                                              \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                                 \
      Ort::ThrowOnError((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,        \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    }                                                                                               \
  } while (false)
2607
// Logs a printf-style formatted message through `logger` when `message_severity` is at or above the
// logger's severity level. The status returned by LogFormattedMessage is deliberately discarded, so this
// form never throws on a failed write. The variadic arguments are the format string followed by its arguments.
//
// `logger` and `message_severity` are parenthesized so that arbitrary expressions can be passed safely,
// e.g. a dereferenced pointer (`*logger_ptr`) as the logger or a conditional expression as the severity.
// Note: both are evaluated twice when the message is emitted, so avoid arguments with side effects.
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...)                             \
  do {                                                                                              \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                                 \
      static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,        \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    }                                                                                               \
  } while (false)
2626
// Lightweight, copyable handle around a `const OrtLogger*`. It stores the pointer together with a
// severity level member (`cached_severity_level_`) and is the type the ORT_CXX_LOG* macros operate on.
// NOTE(review): the handle holds a raw const pointer and declares a defaulted destructor, so it
// presumably does not own the OrtLogger - confirm.
2637struct Logger {
// Default state: null logger pointer and value-initialized severity (see the member initializers).
2641 Logger() = default;
2642
// Same empty state as the default constructor, spelled like the other wrappers' nullptr constructors.
2646 explicit Logger(std::nullptr_t) {}
2647
// Wraps an existing OrtLogger (definition not in this part of the file).
// NOTE(review): presumably also fills `cached_severity_level_` from the logger - confirm.
2654 explicit Logger(const OrtLogger* logger);
2655
2656 ~Logger() = default;
2657
// Copy and move are both allowed and defaulted; they copy the pointer and the cached level.
2658 Logger(const Logger&) = default;
2659 Logger& operator=(const Logger&) = default;
2660
2661 Logger(Logger&& v) noexcept = default;
2662 Logger& operator=(Logger&& v) noexcept = default;
2663
// NOTE(review): the GetLoggingSeverityLevel() declaration used by the ORT_CXX_LOG* macros is missing
// from this listing at this position.
2670
// Logs a plain message with source location information. Declared noexcept: failures are reported
// through the returned Status rather than by throwing.
2683 Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2684 const char* func_name, const char* message) const noexcept;
2685
// Logs a message built from `format` and `args`. Declared noexcept: failures are reported through
// the returned Status. NOTE(review): presumably printf-style formatting - confirm against the definition.
2700 template <typename... Args>
2701 Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2702 const char* func_name, const char* format, Args&&... args) const noexcept;
2703
2704 private:
2705 const OrtLogger* logger_{};
2706 OrtLoggingLevel cached_severity_level_{};
2707};
2708
2717 size_t GetInputCount() const;
2718 size_t GetOutputCount() const;
2719 // If input is optional and is not present, the method returns an empty ConstValue
2720 // which can be compared to nullptr.
2721 ConstValue GetInput(size_t index) const;
2722 // If output is optional and is not present, the method returns an empty UnownedValue
2723 // which can be compared to nullptr.
2724 UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
2725 UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
2726 void* GetGPUComputeStream() const;
2728 Ort::Allocator GetAllocator(const OrtMemoryInfo& memory_info) const;
2729 OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
2730 void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
2731
2732 private:
2733 OrtKernelContext* ctx_;
2734};
2735
2736struct KernelInfo;
2737
2738namespace detail {
// Non-template helpers that read a named attribute from an OrtKernelInfo into an out-parameter.
// KernelInfoImpl::GetAttribute / GetAttributes dispatch to these by overload on the out-parameter type,
// which is why only float, int64_t and std::string (and vectors of float / int64_t) are supported.
2739namespace attr_utils {
// Single-valued attributes.
2740void GetAttr(const OrtKernelInfo* p, const char* name, float&);
2741void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
2742void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
// Array-valued attributes.
2743void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
2744void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2745} // namespace attr_utils
2746
// Query interface over an OrtKernelInfo, shared by the owning KernelInfo and the non-owning
// ConstKernelInfo (both instantiate this template). All members are const.
2747template <typename T>
2748struct KernelInfoImpl : Base<T> {
2749 using B = Base<T>;
2750 using B::B;
2751
// Returns a KernelInfo copy of this object (definition not in this part of the file).
2752 KernelInfo Copy() const;
2753
// Reads a single-valued attribute by name. The value is produced by the attr_utils::GetAttr overload
// selected by R, so R must be one of the types those overloads accept.
2754 template <typename R> // R is only implemented for float, int64_t, and string
2755 R GetAttribute(const char* name) const {
2756 R val;
2757 attr_utils::GetAttr(this->p_, name, val);
2758 return val;
2759 }
2760
// Reads an array-valued attribute by name into a std::vector<R> via the matching attr_utils::GetAttrs
// overload (R is the element type: float or int64_t).
2761 template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
2762 std::vector<R> GetAttributes(const char* name) const {
2763 std::vector<R> result;
2764 attr_utils::GetAttrs(this->p_, name, result);
2765 return result;
2766 }
2767
// Reads a tensor attribute by name and returns it as a Value, using `allocator`.
2768 Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
2769
// Number of inputs / outputs of the kernel's node.
2770 size_t GetInputCount() const;
2771 size_t GetOutputCount() const;
2772
// Name of the input / output at `index`.
2773 std::string GetInputName(size_t index) const;
2774 std::string GetOutputName(size_t index) const;
2775
// Type information of the input / output at `index`.
2776 TypeInfo GetInputTypeInfo(size_t index) const;
2777 TypeInfo GetOutputTypeInfo(size_t index) const;
2778
// Returns the input at `index` as a ConstValue and reports through `is_constant` whether it is constant.
// NOTE(review): the value returned when the input is not constant is not visible here - confirm.
2779 ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
2780
2781 std::string GetNodeName() const;
2782 Logger GetLogger() const;
2783
2784 KeyValuePairs GetConfigEntries() const;
2785
// Identity of the operator this kernel implements, and the execution provider it belongs to.
2786 std::string GetOperatorDomain() const;
2787 std::string GetOperatorType() const;
2788 int GetOperatorSinceVersion() const;
2789 const OrtEp* GetEp() const;
2790};
2791
2792} // namespace detail
2793
2794using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
2795
// Wrapper for OrtKernelInfo; the query interface comes from detail::KernelInfoImpl<OrtKernelInfo>.
2802struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
2803 using Base = detail::KernelInfoImpl<OrtKernelInfo>;
2804 using Base::Base;
// Creates an empty object that holds no underlying OrtKernelInfo.
2805 explicit KernelInfo(std::nullptr_t) {}
// Wraps an existing OrtKernelInfo pointer (definition not in this part of the file).
// NOTE(review): presumably takes ownership of `info` - confirm.
2806 explicit KernelInfo(OrtKernelInfo* info);
// Non-owning const view over the same pointer; this object must outlive the returned view.
2807 ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
2808};
2809
// Wrapper for OrtOp: an operator instance that can be created from within a custom kernel and invoked
// with explicit inputs and outputs.
2813struct Op : detail::Base<OrtOp> {
2815 using Base::Base;
2816
// Creates an empty object that holds no underlying OrtOp.
2817 explicit Op(std::nullptr_t) {}
2818
// Wraps an existing OrtOp pointer (definition not in this part of the file).
2819 explicit Op(OrtOp*);
2820
// Creates an operator by name, domain and version.
// `type_constraint_names` and `type_constraint_values` are parallel arrays of `type_constraint_count`
// entries; `attr_values` holds `attr_count` attributes. `input_count` / `output_count` give the
// number of inputs and outputs the operator will be invoked with.
2821 static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
2822 int version, const char** type_constraint_names,
2823 const ONNXTensorElementDataType* type_constraint_values,
2824 size_t type_constraint_count,
2825 const OpAttr* attr_values,
2826 size_t attr_count,
2827 size_t input_count, size_t output_count);
2828
// Invokes the operator with arrays of C++ Value wrappers.
2829 void Invoke(const OrtKernelContext* context,
2830 const Value* input_values,
2831 size_t input_count,
2832 Value* output_values,
2833 size_t output_count);
2834
// Overload taking arrays of raw OrtValue pointers instead of Value wrappers.
2835 // For easier refactoring
2836 void Invoke(const OrtKernelContext* context,
2837 const OrtValue* const* input_values,
2838 size_t input_count,
2839 OrtValue* const* output_values,
2840 size_t output_count);
2841};
2842
2848 SymbolicInteger(int64_t i) : i_(i), is_int_(true) {};
2849 SymbolicInteger(const char* s) : s_(s), is_int_(false) {};
2852
2855
2856 bool operator==(const SymbolicInteger& dim) const {
2857 if (is_int_ == dim.is_int_) {
2858 if (is_int_) {
2859 return i_ == dim.i_;
2860 } else {
2861 return std::string{s_} == std::string{dim.s_};
2862 }
2863 }
2864 return false;
2865 }
2866
2867 bool IsInt() const { return is_int_; }
2868 int64_t AsInt() const { return i_; }
2869 const char* AsSym() const { return s_; }
2870
2871 static constexpr int INVALID_INT_DIM = -2;
2872
2873 private:
2874 union {
2875 int64_t i_;
2876 const char* s_;
2877 };
2878 bool is_int_;
2879 };
2880
2881 using Shape = std::vector<SymbolicInteger>;
2882
2884
2885 const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
2886
2887 size_t GetInputCount() const { return input_shapes_.size(); }
2888
2890
2891 int64_t GetAttrInt(const char* attr_name);
2892
2893 using Ints = std::vector<int64_t>;
2894 Ints GetAttrInts(const char* attr_name);
2895
2896 float GetAttrFloat(const char* attr_name);
2897
2898 using Floats = std::vector<float>;
2899 Floats GetAttrFloats(const char* attr_name);
2900
2901 std::string GetAttrString(const char* attr_name);
2902
2903 using Strings = std::vector<std::string>;
2904 Strings GetAttrStrings(const char* attr_name);
2905
2906 private:
2907 ConstOpAttr GetAttrHdl(const char* attr_name) const;
2908 const OrtApi* ort_api_;
2910 std::vector<Shape> input_shapes_;
2911};
2912
2914
// Largest end version a custom op can declare: 2^31 - 1 (the maximum value of a 32-bit signed int).
// The replacement list is fully parenthesized so the macro behaves as a single value inside larger
// expressions (e.g. `2 * MAX_CUSTOM_OP_END_VER` or `MAX_CUSTOM_OP_END_VER / 2`).
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
2916
2917template <typename TOp, typename TKernel, bool WithStatus = false>
2921 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
2922
2923 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2924
2925 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2926 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2927 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2928
2929 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2930 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2931
2932#if defined(_MSC_VER) && !defined(__clang__)
2933#pragma warning(push)
2934#pragma warning(disable : 26409)
2935#endif
2936 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
2937#if defined(_MSC_VER) && !defined(__clang__)
2938#pragma warning(pop)
2939#endif
2940 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2941 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2942
2943 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2944 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
2945 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2946 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
2947#ifdef __cpp_if_constexpr
2948 if constexpr (WithStatus) {
2949#else
2950 if (WithStatus) {
2951#endif
2952 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
2953 return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
2954 };
2955 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2956 return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
2957 };
2958 } else {
2961
2962 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2963 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
2964 static_cast<TKernel*>(op_kernel)->Compute(context);
2965 };
2966 }
2967
2968 SetShapeInferFn<TOp>(0);
2969
2970 OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2971 return static_cast<const TOp*>(this_)->start_ver_;
2972 };
2973
2974 OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2975 return static_cast<const TOp*>(this_)->end_ver_;
2976 };
2977
2980 OrtCustomOp::GetAliasMap = nullptr;
2982 }
2983
2984 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2985 const char* GetExecutionProviderType() const { return nullptr; }
2986
2987 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2988 // (inputs and outputs are required by default)
2990 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2991 }
2992
2994 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2995 }
2996
2997 // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
2998 OrtMemType GetInputMemoryType(size_t /*index*/) const {
2999 return OrtMemTypeDefault;
3000 }
3001
3002 // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
3003 // should expect at least 1 argument.
3005 return 1;
3006 }
3007
3008 // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
3009 // to a variadic input should be of the same type.
3011 return true;
3012 }
3013
3014 // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
3015 // should produce at least 1 output value.
3017 return 1;
3018 }
3019
3020 // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
3021 // produced by a variadic output should be of the same type.
3023 return true;
3024 }
3025
3026 // Declare list of session config entries used by this Custom Op.
3027 // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
3028 // This default implementation returns an empty vector of config entries.
3029 std::vector<std::string> GetSessionConfigKeys() const {
3030 return std::vector<std::string>{};
3031 }
3032
3033 // Ort::CustomOpBase derived class should provide the following static method with the type/shape inferencing
3034 // implementation if needed:
3035 // static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context)
3036 template <typename C>
3037 decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
3039 ShapeInferContext ctx(&GetApi(), ort_ctx);
3040 return C::InferOutputShape(ctx);
3041 };
3042 return {};
3043 }
3044
3045 template <typename C>
3049
3050 protected:
3051 // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
3052 void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
3053
3054 int start_ver_ = 1;
3055 int end_ver_ = MAX_CUSTOM_OP_END_VER;
3056};
3057
3058// Forward declaration to resolve circular dependency
3059// on ConstNode
3061
3062namespace detail {
3063template <typename T>
3065 using B = Base<T>;
3066 using B::B;
3067
3069 std::string GetName() const;
3075 std::vector<ValueInfoConsumerProducerInfo> GetConsumers() const;
3085 bool IsGraphOutput() const;
3089 bool IsFromOuterScope() const;
3090};
3091} // namespace detail
3092
3093// Const object holder that does not own the underlying object
3095
3100 ValueInfo() = default; // Same thing as with nullptr
3101 explicit ValueInfo(std::nullptr_t) {}
3103 explicit ValueInfo(OrtValueInfo* p) : ConstValueInfoImpl<OrtValueInfo>{p} {}
3104
3105#if !defined(ORT_MINIMAL_BUILD)
3106 // Create ValueInfo for a tensor
3107 explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info);
3108#endif
3109 ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; }
3110};
3111
3112// Forward declaration
3113struct AttrNameSubgraph;
3114
3115namespace detail {
3116// Forward decl
3117template <typename T>
3118struct ConstGraphImpl;
3119
// Read-only query interface over an OrtNode. Each accessor wraps the OrtApi function named in the
// comment directly above it. All members are const.
3120template <typename T>
3121struct ConstNodeImpl : Base<T> {
3122 using B = Base<T>;
3123 using B::B;
3124
3125 // <Wraps OrtApi::Node_GetId
3126 size_t GetId() const;
3127 // <Wraps OrtApi::Node_GetName
3128 std::string GetName() const;
3129 // <Wraps OrtApi::Node_GetOperatorType
3130 std::string GetOperatorType() const;
3131 // <Wraps OrtApi::Node_GetDomain
3132 std::string GetDomain() const;
3133 // <Wraps OrtApi::Node_GetSinceVersion
3134 int GetSinceVersion() const;
3135
3136 // <Wraps OrtApi::Node_Inputs
3137 std::vector<ConstValueInfo> GetInputs() const;
3138 // <Wraps OrtApi::Node_Outputs
3139 std::vector<ConstValueInfo> GetOutputs() const;
3140 // <Wraps OrtApi::Node_ImplicitInputs
3141 std::vector<ConstValueInfo> GetImplicitInputs() const;
3142 // <Wraps OrtApi::Node_GetAttributes
3143 std::vector<ConstOpAttr> GetAttributes() const;
3144 // <Wraps OrtApi::Node_GetAttributeByName
3145 // Please, read C API doc for details
// Reports success or failure through the returned Status and the attribute through `attr`.
3146 Status GetAttributeByName(const std::string& name, ConstOpAttr& attr) const;
3147 // <Wraps OrtApi::Node_GetSubgraphs
3148 std::vector<AttrNameSubgraph> GetSubgraphs() const;
3149 // <Wraps OrtApi::Node_GetGraph
3150 // ConstGraph is not available yet
// NOTE(review): the declaration that the two comments above describe is missing from this listing.
3152 // <Wraps OrtApi::Node_GetEpName
3153 std::string GetEpName() const;
3154};
3155} // namespace detail
3156
3158
// C++ wrapper around an OrtNode pointer. Derives from detail::ConstNodeImpl<OrtNode>, so the
// accessors declared on ConstNodeImpl are available on a Node as well.
// NOTE(review): presumably this is the owning counterpart of ConstNode (released via Base<OrtNode>) - confirm.
3162struct Node : detail::ConstNodeImpl<OrtNode> {
3163 Node() = default; // Same thing as with nullptr
 // Creates an empty wrapper; the constructor body does nothing, so no OrtNode is created here.
3164 explicit Node(std::nullptr_t) {}
 // Wraps an existing OrtNode pointer by forwarding it to the base class.
 // NOTE(review): presumably takes ownership of `p` - confirm against detail::Base.
3165 explicit Node(OrtNode* p) : ConstNodeImpl<OrtNode>{p} {}
3166
 // The constructors below are only declared when ORT_MINIMAL_BUILD is not defined.
3167#if !defined(ORT_MINIMAL_BUILD)
 // Constructs a node from its operator name and domain, the node name, and the names of its
 // inputs and outputs. Defined out of line.
3168 Node(const std::string& operator_name, const std::string& operator_domain,
3169 const std::string& node_name,
3170 const std::vector<std::string>& input_names,
3171 const std::vector<std::string>& output_names);
3172
 // Same as the overload above, with an additional list of attributes.
 // NOTE(review): `attributes` is a non-const reference - presumably the OpAttr instances are
 // consumed/released by the call; verify in onnxruntime_cxx_inline.h before reusing the vector.
3176 Node(const std::string& operator_name, const std::string& operator_domain,
3177 const std::string& node_name,
3178 const std::vector<std::string>& input_names,
3179 const std::vector<std::string>& output_names,
3180 std::vector<OpAttr>& attributes);
3181
3182 private:
 // Static helper taking the same arguments as the attribute-accepting constructor plus an
 // OrtNode*& out-parameter `node`.
 // NOTE(review): presumably shared by both constructors to create the underlying OrtNode - confirm.
3183 static void Init(const std::string& operator_name, const std::string& operator_domain,
3184 const std::string& node_name,
3185 const std::vector<std::string>& input_names,
3186 const std::vector<std::string>& output_names,
3187 std::vector<OpAttr>& attributes,
3188 OrtNode*& node);
3189#endif // !defined(ORT_MINIMAL_BUILD)
3190};
3191
3192// Return struct for some of ValueInfo APIs.
3193// Must be declared after ConstNode is available.
3196 // either the producer's output index or the consumer's input index
3197 // a producer output index is never negative; a consumer input index can be -1
3198 int64_t index;
3199};
3200
3201// Represents a return value for Graph::GetOperatorSets()
3203 std::string domain;
3204 int64_t version;
3205};
3206
3207namespace detail {
3208template <typename T>
3210 using B = Base<T>;
3211 using B::B;
3212
3213 // <Wraps OrtApi::Graph_GetName
3214 std::string GetName() const;
3215 // <Wraps OrtApi::Graph_GetModelPath
3216 std::basic_string<ORTCHAR_T> GetModelPath() const;
3217 // <Wraps OrtApi::Graph_GetOnnxIRVersion
3218 int64_t GetOnnxIRVersion() const;
3219 // <Wraps OrtApi::Graph_GetOperatorSets
3220 std::vector<OperatorSet> GetOperatorSets() const;
3221 // <Wraps OrtApi::Graph_Inputs
3222 std::vector<ConstValueInfo> GetInputs() const;
3223 // <Wraps OrtApi::Graph_Outputs
3224 std::vector<ConstValueInfo> GetOutputs() const;
3225 // <Wraps OrtApi::Graph_Initializers
3226 std::vector<ConstValueInfo> GetInitializers() const;
3227 // <Wraps OrtApi::Graph_GetNodes
3228 std::vector<ConstNode> GetNodes() const;
3229 // <Wraps OrtApi::Graph_GetParentGraph
3231 // <Wraps OrtApi::Graph_GetGraphView
3232 Graph GetGraphView(const std::vector<ConstNode>& nodes) const;
3233 // <Wraps OrtApi::Graph_GetModelMetadata
3235};
3236
3237template <typename T>
3240 using B::B;
3241
3242#if !defined(ORT_MINIMAL_BUILD)
3243 // <Wraps GetModelEditorApi().SetGraphInputs()
3244 void SetInputs(std::vector<ValueInfo>& inputs);
3245 // <Wraps GetModelEditorApi().SetGraphOutputs()
3246 void SetOutputs(std::vector<ValueInfo>& outputs);
3247 // <Wraps GetModelEditorApi().AddInitializerToGraph()
3248 void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value
3249 // <Wraps GetModelEditorApi().AddNodeToGraph()
3250 void AddNode(Node& node); // Graph takes ownership of Node
3251#endif // !defined(ORT_MINIMAL_BUILD)
3252};
3253} // namespace detail
3254
3256
3257// Return value for Node API
3258// Must be declared after ConstGraph
3263
3267struct Graph : detail::GraphImpl<OrtGraph> {
3268 explicit Graph(std::nullptr_t) {}
3269 explicit Graph(OrtGraph* p) : GraphImpl<OrtGraph>{p} {}
3270#if !defined(ORT_MINIMAL_BUILD)
3271 // <Wraps GetModelEditorApi().CreateGraph()
3273#endif
3274};
3275
3276namespace detail {
3277template <typename T>
3280 using B::B;
3281
3282#if !defined(ORT_MINIMAL_BUILD)
3283 // <Wraps GetModelEditorApi().AddGraphToModel()
3284 void AddGraph(Graph& graph);
3285#endif
3286};
3287} // namespace detail
3288
3289// Const object holder that does not own the underlying object
3291
// C++ wrapper around an OrtModel pointer, built on detail::ModelImpl<OrtModel>.
// NOTE(review): presumably the owning type (released via Base<OrtModel>) - confirm.
3295struct Model : detail::ModelImpl<OrtModel> {
 // Pair of std::string and int used to describe one opset entry.
 // NOTE(review): presumably (operator-set domain, opset version) - confirm against CreateModel.
3296 using DomainOpsetPair = std::pair<std::string, int>;
3297
 // Creates an empty wrapper; the constructor body does nothing, so no OrtModel is created here.
3298 explicit Model(std::nullptr_t) {}
 // Wraps an existing OrtModel pointer by forwarding it to the base class.
 // NOTE(review): presumably takes ownership of `p` - confirm against detail::Base.
3299 explicit Model(OrtModel* p) : ModelImpl<OrtModel>{p} {}
3300
 // Only declared when ORT_MINIMAL_BUILD is not defined.
3301#if !defined(ORT_MINIMAL_BUILD)
3302 //< Wraps GetModelEditorApi().CreateModel()
 // Constructs a model from the given list of opset entries. Defined out of line.
3303 explicit Model(const std::vector<DomainOpsetPair>& opsets);
3304#endif
3305};
3306
3307namespace detail {
3308template <typename T>
3310 using B = Base<T>;
3311 using B::B;
3312
3314 const char* GetOperatorType() const;
3315
3317 const char* GetDomain() const;
3318
3320 std::pair<int, int> GetSinceVersion() const;
3321
3323 const char* GetExecutionProvider() const;
3324
3326 OrtMemType GetInputMemType(size_t input_index) const;
3327
3329 OrtMemType GetOutputMemType(size_t output_index) const;
3330};
3331} // namespace detail
3332
3334
3337 using Base::Base;
3338
3339 explicit KernelDef(std::nullptr_t) {}
3340 explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl<OrtKernelDef>{p} {}
3341
3342 ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; }
3343};
3344
3349struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
3351 explicit KernelDefBuilder(std::nullptr_t) {}
3352 explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder);
3353
3354 KernelDefBuilder& SetOperatorType(const char* op_type);
3355 KernelDefBuilder& SetDomain(const char* domain);
3356 KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end);
3358 KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type);
3359 KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type);
3360 KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtDataType* data_type);
3361 KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector<const OrtDataType*>& data_types);
3362 KernelDefBuilder& AddInputOutputAlias(int input_index, int output_index);
3363 KernelDefBuilder& AddInputOutputAliases(const std::vector<int>& input_indices,
3364 const std::vector<int>& output_indices);
3365 KernelDefBuilder& AddInputOutputMutableAlias(int input_index, int output_index);
3366 KernelDefBuilder& AddInputOutputMutableAliases(const std::vector<int>& input_indices,
3367 const std::vector<int>& output_indices);
3368
3370};
3371
3376struct KernelRegistry : detail::Base<OrtKernelRegistry> {
3379
3381 explicit KernelRegistry(std::nullptr_t) {}
3382
3384 explicit KernelRegistry(OrtKernelRegistry* ort_kernel_registry);
3385
3387 Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
3388 void* kernel_create_func_state);
3389};
3390
3391namespace detail {
3392template <typename T>
3395 using B::B;
3396
3397 //< Wraps SharedPrePackedWeightCache_StoreWeightData
3398 Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers);
3399};
3400} // namespace detail
3401
3419} // namespace Ort
3420#include "onnxruntime_cxx_inline.h"
struct OrtMemoryInfo OrtMemoryInfo
Definition onnxruntime_c_api.h:296
struct OrtKernelInfo OrtKernelInfo
Definition onnxruntime_c_api.h:450
struct OrtNode OrtNode
Definition onnxruntime_c_api.h:324
OrtLoggingLevel
Logging severity levels.
Definition onnxruntime_c_api.h:246
OrtMemoryInfoDeviceType
This mimics OrtDevice type constants so they can be returned in the API.
Definition onnxruntime_c_api.h:485
struct OrtShapeInferContext OrtShapeInferContext
Definition onnxruntime_c_api.h:321
void(* OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
Definition onnxruntime_c_api.h:414
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
Definition onnxruntime_c_api.h:938
OrtCustomOpInputOutputCharacteristic
Definition onnxruntime_c_api.h:6687
struct OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsV2
Definition onnxruntime_c_api.h:313
struct OrtThreadingOptions OrtThreadingOptions
Definition onnxruntime_c_api.h:310
struct OrtSequenceTypeInfo OrtSequenceTypeInfo
Definition onnxruntime_c_api.h:304
struct OrtValueInfo OrtValueInfo
Definition onnxruntime_c_api.h:323
struct OrtDnnlProviderOptions OrtDnnlProviderOptions
Definition onnxruntime_c_api.h:317
OrtSparseIndicesFormat
Definition onnxruntime_c_api.h:235
struct OrtPrepackedWeightsContainer OrtPrepackedWeightsContainer
Definition onnxruntime_c_api.h:312
struct OrtSession OrtSession
Definition onnxruntime_c_api.h:298
OrtCompiledModelCompatibility
The C API.
Definition onnxruntime_c_api.h:961
OrtStatus *(* EpSelectionDelegate)(const OrtEpDevice **ep_devices, size_t num_devices, const OrtKeyValuePairs *model_metadata, const OrtKeyValuePairs *runtime_metadata, const OrtEpDevice **selected, size_t max_selected, size_t *num_selected, void *state)
Delegate to allow providing custom OrtEpDevice selection logic.
Definition onnxruntime_c_api.h:529
struct OrtCustomOpDomain OrtCustomOpDomain
Definition onnxruntime_c_api.h:307
struct OrtIoBinding OrtIoBinding
Definition onnxruntime_c_api.h:297
struct OrtExternalInitializerInfo OrtExternalInitializerInfo
Definition onnxruntime_c_api.h:332
OrtAllocatorType
Definition onnxruntime_c_api.h:456
struct OrtOp OrtOp
Definition onnxruntime_c_api.h:318
struct OrtTypeInfo OrtTypeInfo
Definition onnxruntime_c_api.h:301
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:302
struct OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2
Definition onnxruntime_c_api.h:315
struct OrtKernelContext OrtKernelContext
Definition onnxruntime_c_api.h:452
struct OrtCANNProviderOptions OrtCANNProviderOptions
Definition onnxruntime_c_api.h:316
struct OrtEpDevice OrtEpDevice
Definition onnxruntime_c_api.h:329
void(* RunAsyncCallbackFn)(void *user_data, OrtValue **outputs, size_t num_outputs, OrtStatusPtr status)
Callback function for RunAsync.
Definition onnxruntime_c_api.h:949
OrtHardwareDeviceType
Definition onnxruntime_c_api.h:492
struct OrtModel OrtModel
Definition onnxruntime_c_api.h:326
struct OrtGraph OrtGraph
Definition onnxruntime_c_api.h:325
struct OrtSyncStream OrtSyncStream
Definition onnxruntime_c_api.h:331
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:306
OrtDeviceMemoryType
This matches OrtDevice::MemoryType values.
Definition onnxruntime_c_api.h:478
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:299
OrtStatus *(* OrtWriteBufferFunc)(void *state, const void *buffer, size_t buffer_num_bytes)
Function called by ORT to write a buffer to a custom destination (e.g., file, stream,...
Definition onnxruntime_c_api.h:548
GraphOptimizationLevel
Graph optimization level.
Definition onnxruntime_c_api.h:423
struct OrtKeyValuePairs OrtKeyValuePairs
Definition onnxruntime_c_api.h:330
OrtStatus * OrtStatusPtr
Definition onnxruntime_c_api.h:337
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
Definition onnxruntime_c_api.h:466
OrtSparseFormat
Definition onnxruntime_c_api.h:227
ONNXType
Definition onnxruntime_c_api.h:215
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:294
OrtErrorCode
Definition onnxruntime_c_api.h:254
struct OrtStatus OrtStatus
Definition onnxruntime_c_api.h:295
OrtStatus *(* OrtGetInitializerLocationFunc)(void *state, const char *initializer_name, const OrtValue *initializer_value, const OrtExternalInitializerInfo *external_info, OrtExternalInitializerInfo **new_external_info)
Function called by ORT to allow user to specify how an initializer should be saved,...
Definition onnxruntime_c_api.h:582
#define ORT_API_VERSION
The API version defined in this header.
Definition onnxruntime_c_api.h:41
struct OrtLogger OrtLogger
Definition onnxruntime_c_api.h:320
struct OrtMapTypeInfo OrtMapTypeInfo
Definition onnxruntime_c_api.h:303
struct OrtArenaCfg OrtArenaCfg
Definition onnxruntime_c_api.h:311
ExecutionMode
Definition onnxruntime_c_api.h:431
OrtOpAttrType
Definition onnxruntime_c_api.h:272
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
Definition onnxruntime_c_api.h:931
ONNXTensorElementDataType
Definition onnxruntime_c_api.h:184
OrtExecutionProviderDevicePolicy
These are the default EP selection policies used by ORT when doing automatic EP selection.
Definition onnxruntime_c_api.h:500
const OrtApiBase * OrtGetApiBase(void)
The Onnxruntime library's entry point to access the C API.
@ ORT_LOGGING_LEVEL_WARNING
Warning messages.
Definition onnxruntime_c_api.h:249
@ OrtMemTypeDefault
The default allocator for execution provider.
Definition onnxruntime_c_api.h:474
@ ORT_FAIL
Definition onnxruntime_c_api.h:256
@ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
Definition onnxruntime_c_api.h:186
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
void OrtRelease(OrtAllocator *ptr)
Definition onnxruntime_cxx_api.h:615
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
All C++ Onnxruntime APIs are defined inside this namespace.
Definition onnxruntime_cxx_api.h:48
const OrtModelEditorApi & GetModelEditorApi()
This returns a reference to the ORT C Model Editor API. Used if building or augmenting a model at run...
Definition onnxruntime_cxx_api.h:215
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Definition onnxruntime_cxx_api.h:791
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions > > ConstSessionOptions
Definition onnxruntime_cxx_api.h:1460
detail::KernelInfoImpl< detail::Unowned< const OrtKernelInfo > > ConstKernelInfo
Definition onnxruntime_cxx_api.h:2794
const OrtApi & GetApi() noexcept
This returns a reference to the ORT C API.
Definition onnxruntime_cxx_api.h:189
const OrtCompileApi & GetCompileApi()
This returns a reference to the ORT C Compile API. Used if compiling a model at runtime.
Definition onnxruntime_cxx_api.h:229
detail::AllocatorImpl< detail::Unowned< OrtAllocator > > UnownedAllocator
Definition onnxruntime_cxx_api.h:1067
OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(const std::vector< ConstEpDevice > &ep_devices, const char *compatibility_info)
Validate a compiled model's compatibility for one or more EP devices.
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions > > UnownedSessionOptions
Definition onnxruntime_cxx_api.h:1459
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id,...
const OrtEpApi & GetEpApi()
This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider.
Definition onnxruntime_cxx_api.h:243
std::string GetVersionString()
This function returns the onnxruntime version string.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Ort::Status(*)(Ort::ShapeInferContext &) ShapeInferFn
Definition onnxruntime_cxx_api.h:2913
Status CompileModel(const Env &env, const ModelCompilationOptions &model_compilation_options)
Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels....
Wrapper around OrtAllocator.
Definition onnxruntime_cxx_api.h:1059
Allocator(const Session &session, const OrtMemoryInfo *)
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition onnxruntime_cxx_api.h:1060
Allocator(OrtAllocator *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:1064
Wrapper around OrtAllocator default instance that is owned by Onnxruntime.
Definition onnxruntime_cxx_api.h:1050
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition onnxruntime_cxx_api.h:1051
A structure that represents the configuration of an arena-based allocator.
Definition onnxruntime_cxx_api.h:2492
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2493
ArenaCfg(const std::unordered_map< std::string, size_t > &arena_config)
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk)
Definition onnxruntime_cxx_api.h:3259
ConstGraph sub_graph
Definition onnxruntime_cxx_api.h:3261
std::string attr_name
Definition onnxruntime_cxx_api.h:3260
bfloat16 (Brain Floating Point) data type
Definition onnxruntime_cxx_api.h:413
bool operator==(const BFloat16_t &rhs) const noexcept
onnxruntime_float16::BFloat16Impl< BFloat16_t > Base
Definition onnxruntime_cxx_api.h:425
BFloat16_t()=default
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Explicit construction of a BFloat16_t from its uint16_t bit representation.
Definition onnxruntime_cxx_api.h:434
bool operator!=(const BFloat16_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:532
BFloat16_t(float v) noexcept
__ctor from float. Float is converted into bfloat16 16-bit representation.
Definition onnxruntime_cxx_api.h:440
float ToFloat() const noexcept
Converts bfloat16 to float.
Definition onnxruntime_cxx_api.h:446
bool operator<(const BFloat16_t &rhs) const noexcept
The CUDAProviderOptions (V2)
Definition onnxruntime_cxx_api.h:862
CUDAProviderOptions()
Wraps OrtApi::CreateCUDAProviderOptions.
CUDAProviderOptions(std::nullptr_t)
Definition onnxruntime_cxx_api.h:863
void UpdateWithValue(const char *key, void *value)
Wrapper around OrtApi::UpdateCUDAProviderOptionsWithValue.
std::string GetCUDAProviderOptionsAsString() const
Wrapper around OrtApi::GetCUDAProviderOptionsAsString.
void Update(const std::unordered_map< std::string, std::string > &options)
Wrapper around OrtApi::UpdateCUDAProviderOptions.
void * GetOptionByName(const char *name) const
Wrapper around OrtApi::GetCUDAProviderOptionsByName.
Definition onnxruntime_cxx_api.h:2918
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
Definition onnxruntime_cxx_api.h:2993
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
Definition onnxruntime_cxx_api.h:2989
OrtMemType GetInputMemoryType(size_t) const
Definition onnxruntime_cxx_api.h:2998
std::vector< std::string > GetSessionConfigKeys() const
Definition onnxruntime_cxx_api.h:3029
bool GetVariadicInputHomogeneity() const
Definition onnxruntime_cxx_api.h:3010
int GetVariadicInputMinArity() const
Definition onnxruntime_cxx_api.h:3004
void SetShapeInferFn(...)
Definition onnxruntime_cxx_api.h:3046
CustomOpBase()
Definition onnxruntime_cxx_api.h:2919
bool GetVariadicOutputHomogeneity() const
Definition onnxruntime_cxx_api.h:3022
int GetVariadicOutputMinArity() const
Definition onnxruntime_cxx_api.h:3016
decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape))
Definition onnxruntime_cxx_api.h:3037
const char * GetExecutionProviderType() const
Definition onnxruntime_cxx_api.h:2985
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
Class that represents session configuration entries for one or more custom operators.
Definition onnxruntime_cxx_api.h:1310
~CustomOpConfigs()=default
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
CustomOpConfigs & operator=(CustomOpConfigs &&o)=default
CustomOpConfigs(CustomOpConfigs &&o)=default
CustomOpConfigs()=default
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
CustomOpConfigs(const CustomOpConfigs &)=default
CustomOpConfigs & operator=(const CustomOpConfigs &)=default
Custom Op Domain.
Definition onnxruntime_cxx_api.h:1215
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1219
CustomOpDomain(const char *domain)
Wraps OrtApi::CreateCustomOpDomain.
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
The Env (Environment)
Definition onnxruntime_cxx_api.h:1158
Env & EnableTelemetryEvents()
Wraps OrtApi::EnableTelemetryEvents.
Env(OrtEnv *p)
C Interop Helper.
Definition onnxruntime_cxx_api.h:1175
Env & CreateAndRegisterAllocatorV2(const std::string &provider_type, const OrtMemoryInfo *mem_info, const std::unordered_map< std::string, std::string > &options, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocatorV2.
Env & UnregisterExecutionProviderLibrary(const char *registration_name)
Wraps OrtApi::UnregisterExecutionProviderLibrary.
std::vector< ConstEpDevice > GetEpDevices() const
Env & UnregisterAllocator(const OrtMemoryInfo *mem_info)
Wraps OrtApi::UnregisterAllocator.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1159
Env(OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnv.
Env(const OrtThreadingOptions *tp_options, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithGlobalThreadPools.
Env(const OrtThreadingOptions *tp_options, OrtLoggingFunction logging_function, void *logger_param, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools.
Env & RegisterAllocator(OrtAllocator *allocator)
Wraps OrtApi::RegisterAllocator.
UnownedAllocator CreateSharedAllocator(const OrtEpDevice *ep_device, OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, const OrtKeyValuePairs *allocator_options)
Wraps OrtApi::CreateSharedAllocator.
Env(OrtLoggingLevel logging_level, const char *logid, OrtLoggingFunction logging_function, void *logger_param)
Wraps OrtApi::CreateEnvWithCustomLogger.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
UnownedAllocator GetSharedAllocator(const OrtMemoryInfo *mem_info)
Wraps OrtApi::GetSharedAllocator.
Env & RegisterExecutionProviderLibrary(const char *registration_name, const std::basic_string< char > &path)
Wraps OrtApi::RegisterExecutionProviderLibrary.
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
Status CopyTensors(const std::vector< Value > &src_tensors, const std::vector< Value > &dst_tensors, OrtSyncStream *stream) const
Wraps OrtApi::CopyTensors.
void ReleaseSharedAllocator(const OrtEpDevice *ep_device, OrtDeviceMemoryType mem_type)
Wraps OrtApi::ReleaseSharedAllocator.
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
Mutable EpDevice that is created by EpApi users.
Definition onnxruntime_cxx_api.h:1134
EpDevice(OrtEpDevice *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:1136
EpDevice(OrtEpFactory &ep_factory, ConstHardwareDevice &hardware_device, ConstKeyValuePairs ep_metadata={}, ConstKeyValuePairs ep_options={})
Wraps OrtEpApi::CreateEpDevice.
EpDevice(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:1135
All C++ methods that can fail will throw an exception of this type.
Definition onnxruntime_cxx_api.h:54
const char * what() const noexcept override
Definition onnxruntime_cxx_api.h:59
Exception(const std::string &string, OrtErrorCode code)
Definition onnxruntime_cxx_api.h:55
OrtErrorCode GetOrtErrorCode() const
Definition onnxruntime_cxx_api.h:58
Exception(std::string &&string, OrtErrorCode code)
Definition onnxruntime_cxx_api.h:56
Wrapper around OrtExternalInitializerInfo.
Definition onnxruntime_cxx_api.h:913
ConstExternalInitializerInfo GetConst() const
Definition onnxruntime_cxx_api.h:921
ExternalInitializerInfo(const char *filepath, int64_t file_offset, size_t byte_size)
Wraps OrtApi::CreateExternalInitializerInfo.
ExternalInitializerInfo(std::nullptr_t)
Definition onnxruntime_cxx_api.h:917
ExternalInitializerInfo(OrtExternalInitializerInfo *p)
Definition onnxruntime_cxx_api.h:918
static Status Create(const char *filepath, int64_t file_offset, size_t byte_size, ExternalInitializerInfo &out)
Wrapper around CreateExternalInitializerInfo that does not throw an exception.
IEEE 754 half-precision floating point data type.
Definition onnxruntime_cxx_api.h:271
Float16_t()=default
Default constructor.
Float16_t(float v) noexcept
__ctor from float. Float is converted into float16 16-bit representation.
Definition onnxruntime_cxx_api.h:299
onnxruntime_float16::Float16Impl< Float16_t > Base
Definition onnxruntime_cxx_api.h:281
float ToFloat() const noexcept
Converts float16 to float.
Definition onnxruntime_cxx_api.h:305
static constexpr Float16_t FromBits(uint16_t v) noexcept
Explicit construction of a Float16_t from its uint16_t bit representation.
Definition onnxruntime_cxx_api.h:293
float8e4m3fn (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:543
uint8_t value
Definition onnxruntime_cxx_api.h:544
constexpr Float8E4M3FN_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:546
constexpr bool operator==(const Float8E4M3FN_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:549
constexpr Float8E4M3FN_t() noexcept
Definition onnxruntime_cxx_api.h:545
constexpr bool operator!=(const Float8E4M3FN_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:550
float8e4m3fnuz (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:560
constexpr bool operator==(const Float8E4M3FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:566
uint8_t value
Definition onnxruntime_cxx_api.h:561
constexpr Float8E4M3FNUZ_t() noexcept
Definition onnxruntime_cxx_api.h:562
constexpr bool operator!=(const Float8E4M3FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:567
constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:563
float8e5m2 (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:577
constexpr Float8E5M2_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:580
uint8_t value
Definition onnxruntime_cxx_api.h:578
constexpr bool operator!=(const Float8E5M2_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:584
constexpr Float8E5M2_t() noexcept
Definition onnxruntime_cxx_api.h:579
constexpr bool operator==(const Float8E5M2_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:583
float8e5m2fnuz (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:594
constexpr Float8E5M2FNUZ_t() noexcept
Definition onnxruntime_cxx_api.h:596
constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:597
constexpr bool operator!=(const Float8E5M2FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:601
constexpr bool operator==(const Float8E5M2FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:600
uint8_t value
Definition onnxruntime_cxx_api.h:595
Wrapper around OrtGraph.
Definition onnxruntime_cxx_api.h:3267
Graph(OrtGraph *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3269
Graph(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:3268
Wrapper around OrtIoBinding.
Definition onnxruntime_cxx_api.h:2481
UnownedIoBinding GetUnowned() const
Definition onnxruntime_cxx_api.h:2485
ConstIoBinding GetConst() const
Definition onnxruntime_cxx_api.h:2484
IoBinding(Session &session)
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
Definition onnxruntime_cxx_api.h:2482
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
Definition onnxruntime_cxx_api.h:2715
KernelContext(OrtKernelContext *context)
Logger GetLogger() const
ConstValue GetInput(size_t index) const
OrtKernelContext * GetOrtKernelContext() const
Definition onnxruntime_cxx_api.h:2729
void ParallelFor(void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
void * GetGPUComputeStream() const
size_t GetInputCount() const
Ort::Allocator GetAllocator(const OrtMemoryInfo &memory_info) const
size_t GetOutputCount() const
UnownedValue GetOutput(size_t index, const std::vector< int64_t > &dims) const
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
Builder for OrtKernelDef.
Definition onnxruntime_cxx_api.h:3349
KernelDefBuilder & AddTypeConstraint(const char *arg_name, const OrtDataType *data_type)
KernelDefBuilder & SetOutputMemType(size_t output_index, OrtMemType mem_type)
KernelDefBuilder & AddInputOutputMutableAliases(const std::vector< int > &input_indices, const std::vector< int > &output_indices)
KernelDefBuilder & SetInputMemType(size_t input_index, OrtMemType mem_type)
KernelDefBuilder & SetDomain(const char *domain)
KernelDefBuilder & AddInputOutputAliases(const std::vector< int > &input_indices, const std::vector< int > &output_indices)
KernelDefBuilder & AddInputOutputAlias(int input_index, int output_index)
KernelDefBuilder & SetExecutionProvider(const char *ep_name)
KernelDefBuilder & SetOperatorType(const char *op_type)
KernelDefBuilder & AddInputOutputMutableAlias(int input_index, int output_index)
KernelDefBuilder()
Wraps OrtEpApi::CreateKernelDefBuilder.
KernelDefBuilder & AddTypeConstraint(const char *arg_name, const std::vector< const OrtDataType * > &data_types)
KernelDefBuilder(std::nullptr_t)
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:3351
KernelDefBuilder(OrtKernelDefBuilder *ort_kernel_def_builder)
KernelDefBuilder & SetSinceVersion(int since_version_start, int since_version_end)
Definition onnxruntime_cxx_api.h:3335
KernelDef(OrtKernelDef *p)
Definition onnxruntime_cxx_api.h:3340
KernelDef(std::nullptr_t)
Definition onnxruntime_cxx_api.h:3339
ConstKernelDef GetConst() const
Definition onnxruntime_cxx_api.h:3342
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
Definition onnxruntime_cxx_api.h:2802
KernelInfo(OrtKernelInfo *info)
Take ownership of the instance.
ConstKernelInfo GetConst() const
Definition onnxruntime_cxx_api.h:2807
detail::KernelInfoImpl< OrtKernelInfo > Base
Definition onnxruntime_cxx_api.h:2803
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
Definition onnxruntime_cxx_api.h:2805
Registry for kernels supported by an EP.
Definition onnxruntime_cxx_api.h:3376
KernelRegistry()
Wraps OrtEpApi::CreateKernelRegistry.
KernelRegistry(std::nullptr_t)
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:3381
Status AddKernel(const OrtKernelDef *kernel_def, OrtKernelCreateFunc kernel_create_func, void *kernel_create_func_state)
Wraps KernelRegistry_AddKernel.
KernelRegistry(OrtKernelRegistry *ort_kernel_registry)
Take ownership of a pointer created with the C API.
Wrapper around OrtKeyValuePairs.
Definition onnxruntime_cxx_api.h:950
KeyValuePairs()
Wraps OrtApi::CreateKeyValuePairs.
void Add(const char *key, const char *value)
Wraps OrtApi::AddKeyValuePair.
KeyValuePairs(const std::unordered_map< std::string, std::string > &kv_pairs)
Wraps OrtApi::CreateKeyValuePairs and OrtApi::AddKeyValuePair.
void Remove(const char *key)
Wraps OrtApi::RemoveKeyValuePair.
KeyValuePairs(std::nullptr_t)
Definition onnxruntime_cxx_api.h:951
ConstKeyValuePairs GetConst() const
Definition onnxruntime_cxx_api.h:967
KeyValuePairs(OrtKeyValuePairs *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:953
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Definition onnxruntime_cxx_api.h:2637
Logger(Logger &&v) noexcept=default
Logger & operator=(Logger &&v) noexcept=default
Logger & operator=(const Logger &)=default
~Logger()=default
Logger(const Logger &)=default
Logger()=default
Logger(std::nullptr_t)
Definition onnxruntime_cxx_api.h:2646
Logger(const OrtLogger *logger)
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
LoraAdapter holds a set of Lora Parameters loaded from a single file.
Definition onnxruntime_cxx_api.h:1229
static LoraAdapter CreateLoraAdapter(const std::basic_string< char > &adapter_path, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapter.
LoraAdapter(std::nullptr_t)
Definition onnxruntime_cxx_api.h:1233
static LoraAdapter CreateLoraAdapterFromArray(const void *bytes, size_t num_bytes, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapterFromArray.
Wrapper around OrtMapTypeInfo.
Definition onnxruntime_cxx_api.h:1880
ConstMapTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1886
MapTypeInfo(OrtMapTypeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1885
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1884
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime....
Definition onnxruntime_cxx_api.h:1011
MemoryAllocation(MemoryAllocation &&) noexcept
MemoryAllocation & operator=(const MemoryAllocation &)=delete
MemoryAllocation(const MemoryAllocation &)=delete
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
size_t size() const
Definition onnxruntime_cxx_api.h:1020
Wrapper around OrtMemoryInfo.
Definition onnxruntime_cxx_api.h:995
MemoryInfo(const char *name, OrtAllocatorType type, int id, OrtMemType mem_type)
MemoryInfo(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:997
MemoryInfo(OrtMemoryInfo *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:998
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
ConstMemoryInfo GetConst() const
Definition onnxruntime_cxx_api.h:1002
MemoryInfo(const char *name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type)
Wrapper around CreateMemoryInfo_V2.
Options object used when compiling a model.
Definition onnxruntime_cxx_api.h:1477
ModelCompilationOptions & SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void *state)
ModelCompilationOptions & SetEpContextEmbedMode(bool embed_ep_context_in_model)
Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode.
ModelCompilationOptions & SetInputModelFromBuffer(const void *input_model_data, size_t input_model_data_size)
Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer.
ModelCompilationOptions & SetOutputModelBuffer(OrtAllocator *allocator, void **output_model_buffer_ptr, size_t *output_model_buffer_size_ptr)
Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer.
ModelCompilationOptions & SetFlags(uint32_t flags)
Wraps OrtApi::ModelCompilationOptions_SetFlags.
ModelCompilationOptions & SetOutputModelExternalInitializersFile(const char *file_path, size_t initializer_size_threshold)
Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile.
ModelCompilationOptions(std::nullptr_t)
Create an empty ModelCompilationOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1481
ModelCompilationOptions(const Env &env, ConstSessionOptions session_options)
Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions.
ModelCompilationOptions & SetOutputModelPath(const char *output_model_path)
Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath.
ModelCompilationOptions & SetInputModelPath(const char *input_model_path)
Wraps OrtApi::ModelCompilationOptions_SetInputModelPath.
ModelCompilationOptions & SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, void *state)
ModelCompilationOptions & SetEpContextBinaryInformation(const char *output_directory, const char *model_name)
Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation.
ModelCompilationOptions & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel.
ModelCompilationOptions(const Env &env, const SessionOptions &session_options)
Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions.
Wrapper around OrtModel.
Definition onnxruntime_cxx_api.h:3295
Model(const std::vector< DomainOpsetPair > &opsets)
Model(OrtModel *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3299
std::pair< std::string, int > DomainOpsetPair
Definition onnxruntime_cxx_api.h:3296
Model(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:3298
Wrapper around OrtModelMetadata.
Definition onnxruntime_cxx_api.h:1523
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1527
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator *allocator) const
Returns a copy of the producer name.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Returns a copy of the graph name.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
Wrapper around OrtNode.
Definition onnxruntime_cxx_api.h:3162
Node(const std::string &operator_name, const std::string &operator_domain, const std::string &node_name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names)
Node()=default
Node(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:3164
Node(const std::string &operator_name, const std::string &operator_domain, const std::string &node_name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names, std::vector< OpAttr > &attributes)
Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in attributes ...
Node(OrtNode *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3165
This struct provides life time management for custom op attribute.
Definition onnxruntime_cxx_api.h:2546
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
OpAttr()=default
OpAttr(std::nullptr_t)
Definition onnxruntime_cxx_api.h:2551
ConstOpAttr GetConst() const
Definition onnxruntime_cxx_api.h:2554
Create and own custom defined operation.
Definition onnxruntime_cxx_api.h:2813
Op(OrtOp *)
Take ownership of the OrtOp.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2817
void Invoke(const OrtKernelContext *context, const OrtValue *const *input_values, size_t input_count, OrtValue *const *output_values, size_t output_count)
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
Definition onnxruntime_cxx_api.h:3202
std::string domain
Definition onnxruntime_cxx_api.h:3203
int64_t version
Definition onnxruntime_cxx_api.h:3204
The PrepackedWeightsContainer.
Definition onnxruntime_cxx_api.h:881
PrepackedWeightsContainer()
Wraps OrtApi::CreatePrepackedWeightsContainer.
PrepackedWeightsContainer(OrtPrepackedWeightsContainer *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:886
PrepackedWeightsContainer(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:884
RunOptions.
Definition onnxruntime_cxx_api.h:1257
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance.
RunOptions & SetRunTag(const char *run_tag)
Wraps OrtApi::RunOptionsSetRunTag.
RunOptions & AddActiveLoraAdapter(const LoraAdapter &adapter)
Add the LoraAdapter to the list of active adapters. The setting does not affect RunWithBinding() call...
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1258
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
RunOptions()
Wraps OrtApi::CreateRunOptions.
const char * GetConfigEntry(const char *config_key)
Wraps OrtApi::GetRunConfigEntry.
Wrapper around OrtSequenceTypeInfo.
Definition onnxruntime_cxx_api.h:1842
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1846
ConstSequenceTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1848
SequenceTypeInfo(OrtSequenceTypeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1847
Wrapper around OrtSession.
Definition onnxruntime_cxx_api.h:1743
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1745
static Session CreateModelEditorSession(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtModelEditorApi::CreateModelEditorSession.
UnownedSession GetUnowned() const
Definition onnxruntime_cxx_api.h:1774
Session(const Env &env, const char *model_path, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer.
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer.
Session(const Env &env, const Model &model, const SessionOptions &options)
Wraps OrtModelEditorApi::CreateSessionFromModel.
Session(OrtSession *p)
C API Interop.
Definition onnxruntime_cxx_api.h:1746
static Session CreateModelEditorSession(const Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtModelEditorApi::CreateModelEditorSession.
Session(const Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtApi::CreateSession.
ConstSession GetConst() const
Definition onnxruntime_cxx_api.h:1773
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtApi::CreateSessionFromArray.
Wrapper around OrtSessionOptions.
Definition onnxruntime_cxx_api.h:1465
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1466
UnownedSessionOptions GetUnowned() const
Definition onnxruntime_cxx_api.h:1469
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
ConstSessionOptions GetConst() const
Definition onnxruntime_cxx_api.h:1470
SessionOptions(OrtSessionOptions *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1468
Definition onnxruntime_cxx_api.h:2847
SymbolicInteger & operator=(const SymbolicInteger &)=default
SymbolicInteger(const SymbolicInteger &)=default
int64_t AsInt() const
Definition onnxruntime_cxx_api.h:2868
int64_t i_
Definition onnxruntime_cxx_api.h:2875
const char * s_
Definition onnxruntime_cxx_api.h:2876
bool operator==(const SymbolicInteger &dim) const
Definition onnxruntime_cxx_api.h:2856
SymbolicInteger & operator=(SymbolicInteger &&)=default
SymbolicInteger(SymbolicInteger &&)=default
const char * AsSym() const
Definition onnxruntime_cxx_api.h:2869
SymbolicInteger(int64_t i)
Definition onnxruntime_cxx_api.h:2848
SymbolicInteger(const char *s)
Definition onnxruntime_cxx_api.h:2849
bool IsInt() const
Definition onnxruntime_cxx_api.h:2867
Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
Definition onnxruntime_cxx_api.h:2846
Ints GetAttrInts(const char *attr_name)
Strings GetAttrStrings(const char *attr_name)
Status SetOutputShape(size_t indice, const Shape &shape, ONNXTensorElementDataType type=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
std::vector< SymbolicInteger > Shape
Definition onnxruntime_cxx_api.h:2881
std::vector< float > Floats
Definition onnxruntime_cxx_api.h:2898
std::string GetAttrString(const char *attr_name)
std::vector< int64_t > Ints
Definition onnxruntime_cxx_api.h:2893
ShapeInferContext(const OrtApi *ort_api, OrtShapeInferContext *ctx)
int64_t GetAttrInt(const char *attr_name)
size_t GetInputCount() const
Definition onnxruntime_cxx_api.h:2887
std::vector< std::string > Strings
Definition onnxruntime_cxx_api.h:2903
Floats GetAttrFloats(const char *attr_name)
const Shape & GetInputShape(size_t indice) const
Definition onnxruntime_cxx_api.h:2885
float GetAttrFloat(const char *attr_name)
The Status that holds ownership of OrtStatus received from C API. Use it to safely destroy OrtStatus* ...
Definition onnxruntime_cxx_api.h:797
OrtErrorCode GetErrorCode() const
Status(const Exception &)
Creates status instance out of exception.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
Status(OrtStatus *status) noexcept
Takes ownership of OrtStatus instance returned from the C API.
std::string GetErrorMessage() const
Status()=default
Status(const std::exception &)
Creates status instance out of exception.
Status(const char *message, OrtErrorCode code)
Creates status instance out of null-terminated string message.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:799
Definition onnxruntime_cxx_api.h:1083
SyncStream(OrtSyncStream *p)
Definition onnxruntime_cxx_api.h:1087
SyncStream(std::nullptr_t)
Create an empty SyncStream object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1085
The TensorRTOptions (V2)
Definition onnxruntime_cxx_api.h:843
void Update(const std::unordered_map< std::string, std::string > &options)
Wrapper around OrtApi::UpdateTensorRTProviderOptions.
void UpdateWithValue(const char *key, void *value)
Wrapper around OrtApi::UpdateTensorRTProviderOptionsWithValue.
std::string GetTensorRTProviderOptionsAsString() const
Wrapper around OrtApi::GetTensorRTProviderOptionsAsString.
void * GetOptionByName(const char *name) const
Wrapper around OrtApi::GetTensorRTProviderOptionsByName.
TensorRTProviderOptions(std::nullptr_t)
Definition onnxruntime_cxx_api.h:844
TensorRTProviderOptions()
Wraps OrtApi::CreateTensorRTProviderOptionsV2.
Wrapper around OrtTensorTypeAndShapeInfo.
Definition onnxruntime_cxx_api.h:1808
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1813
ConstTensorTypeAndShapeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1824
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1815
TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, const std::vector< int64_t > &dims, const std::vector< std::string > *symbolic_dims=nullptr)
The ThreadingOptions.
Definition onnxruntime_cxx_api.h:813
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
Definition onnxruntime_cxx_api.h:1914
static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type)
static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type)
static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info)
static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info)
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1919
static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type)
ConstTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1930
TypeInfo(OrtTypeInfo *p)
C API Interop.
Definition onnxruntime_cxx_api.h:1920
Wrapper around OrtValue.
Definition onnxruntime_cxx_api.h:2270
static Value CreateSparseTensor(const OrtMemoryInfo *info, void *p_data, const Shape &dense_shape, const Shape &values_shape, ONNXTensorElementDataType type)
Creates an OrtValue instance containing SparseTensor. This constructs a sparse tensor that makes use ...
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
Value & operator=(Value &&)=default
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape, ONNXTensorElementDataType type)
Creates an instance of OrtValue containing sparse tensor. The created instance has no data....
Value(Value &&)=default
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2276
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape)
This is a simple forwarding method to the below CreateSparseTensor. This helps to specify data type e...
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates an OrtValue with a tensor using the supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtVal...
UnownedValue GetUnowned() const
Definition onnxruntime_cxx_api.h:2281
static Value CreateSequence(const std::vector< Value > &values)
Creates an OrtValue with a Sequence Onnx type representation. The API would ref-count the supplied Or...
static Value CreateMap(const Value &keys, const Value &values)
Creates an OrtValue with a Map Onnx type representation. The API would ref-count the supplied OrtValu...
static Value CreateTensor(const OrtMemoryInfo *info, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len)
Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue...
static Value CreateOpaque(const char *domain, const char *type_name, const T &value)
Creates an OrtValue wrapping an Opaque type. This is used for experimental support of non-tensor type...
static Value CreateTensor(OrtAllocator *deleter, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue.
ConstValue GetConst() const
Definition onnxruntime_cxx_api.h:2280
Definition onnxruntime_cxx_api.h:3194
int64_t index
Definition onnxruntime_cxx_api.h:3198
ConstNode node
Definition onnxruntime_cxx_api.h:3195
Wrapper around OrtValueInfo.
Definition onnxruntime_cxx_api.h:3099
ConstValueInfo GetConst() const
Definition onnxruntime_cxx_api.h:3109
ValueInfo(std::nullptr_t)
Definition onnxruntime_cxx_api.h:3101
ValueInfo(const std::string &name, const ConstTypeInfo &type_info)
ValueInfo(OrtValueInfo *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3103
ValueInfo()=default
Definition onnxruntime_cxx_api.h:759
AllocatedFree(OrtAllocator *allocator)
Definition onnxruntime_cxx_api.h:761
OrtAllocator * allocator_
Definition onnxruntime_cxx_api.h:760
void operator()(void *ptr) const
Definition onnxruntime_cxx_api.h:763
Base & operator=(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:745
constexpr contained_type & operator*() const noexcept
Definition onnxruntime_cxx_api.h:752
typename Unowned< T >::Type contained_type
Definition onnxruntime_cxx_api.h:734
Base(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:744
Base(const Base &)=default
constexpr Base(contained_type *p) noexcept
Definition onnxruntime_cxx_api.h:737
Base & operator=(const Base &)=default
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition onnxruntime_cxx_api.h:687
Base(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:699
constexpr Base()=default
constexpr contained_type & operator*() const noexcept
Definition onnxruntime_cxx_api.h:707
contained_type * release()
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed.
Definition onnxruntime_cxx_api.h:711
Base(const Base &)=delete
constexpr Base(contained_type *p) noexcept
Definition onnxruntime_cxx_api.h:691
Base & operator=(const Base &)=delete
Base & operator=(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:700
contained_type * p_
Definition onnxruntime_cxx_api.h:718
~Base()
Definition onnxruntime_cxx_api.h:692
T contained_type
Definition onnxruntime_cxx_api.h:688
Definition onnxruntime_cxx_api.h:893
const std::basic_string< char > GetFilePath() const
Definition onnxruntime_cxx_api.h:3209
std::vector< ConstNode > GetNodes() const
std::vector< ConstValueInfo > GetInputs() const
ConstNode GetParentNode() const
int64_t GetOnnxIRVersion() const
std::basic_string< char > GetModelPath() const
Graph GetGraphView(const std::vector< ConstNode > &nodes) const
ModelMetadata GetModelMetadata() const
Wraps OrtApi::Graph_GetModelMetadata.
std::vector< ConstValueInfo > GetInitializers() const
std::string GetName() const
std::vector< ConstValueInfo > GetOutputs() const
std::vector< OperatorSet > GetOperatorSets() const
Definition onnxruntime_cxx_api.h:2449
std::vector< Value > GetOutputValues(OrtAllocator *) const
std::vector< std::string > GetOutputNames(OrtAllocator *) const
std::vector< Value > GetOutputValues() const
std::vector< std::string > GetOutputNames() const
Definition onnxruntime_cxx_api.h:3309
std::pair< int, int > GetSinceVersion() const
Wraps OrtEpApi::KernelDef_GetSinceVersion.
const char * GetDomain() const
Wraps OrtEpApi::KernelDef_GetDomain.
OrtMemType GetOutputMemType(size_t output_index) const
Wraps OrtEpApi::KernelDef_GetOutputMemType.
const char * GetExecutionProvider() const
Wraps OrtEpApi::KernelDef_GetExecutionProvider.
OrtMemType GetInputMemType(size_t input_index) const
Wraps OrtEpApi::KernelDef_GetInputMemType.
const char * GetOperatorType() const
Wraps OrtEpApi::KernelDef_GetOperatorType.
Definition onnxruntime_cxx_api.h:3121
std::vector< ConstValueInfo > GetOutputs() const
std::vector< ConstValueInfo > GetImplicitInputs() const
std::string GetName() const
std::string GetDomain() const
std::vector< AttrNameSubgraph > GetSubgraphs() const
ConstGraphImpl< detail::Unowned< const OrtGraph > > GetGraph() const
std::string GetOperatorType() const
std::vector< ConstOpAttr > GetAttributes() const
std::vector< ConstValueInfo > GetInputs() const
Status GetAttributeByName(const std::string &name, ConstOpAttr &attr) const
std::string GetEpName() const
Definition onnxruntime_cxx_api.h:2518
std::string GetName() const
Status GetValue(R &out) const
Status GetTensorAttributeAsOrtValue(Value &) const
Status GetValueArray(std::vector< R > &out) const
OrtOpAttrType GetType() const
Definition onnxruntime_cxx_api.h:1598
std::vector< std::string > GetOutputNames() const
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
Returns the number of model outputs.
std::vector< ValueInfo > GetOutputs() const
int GetOpset(const std::string &domain) const
Wraps OrtApi::SessionGetOpsetForDomain.
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
std::vector< std::string > GetOverridableInitializerNames() const
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
size_t GetInputCount() const
Returns the number of model inputs.
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
std::vector< std::string > GetInputNames() const
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
std::vector< ConstEpDevice > GetEpDeviceForInputs() const
Wrapper for OrtApi::SessionGetEpDeviceForInputs.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
std::vector< ConstMemoryInfo > GetMemoryInfoForOutputs() const
Wrapper for OrtApi::SessionGetMemoryInfoForOutputs.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
std::vector< ConstMemoryInfo > GetMemoryInfoForInputs() const
Wrapper for OrtApi::SessionGetMemoryInfoForInputs.
std::vector< ValueInfo > GetInputs() const
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
Definition onnxruntime_cxx_api.h:1959
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
std::string GetStringTensorElement(size_t element_index) const
Returns string tensor UTF-8 encoded string element. Use of this API is recommended over GetStringTens...
size_t GetStringTensorElementLength(size_t element_index) const
The API returns the byte length of a UTF-8 encoded string element contained in either a tensor or a sparse...
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor....
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor.
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
size_t GetTensorSizeInBytes() const
Returns the total size of the tensor data in bytes. Throws an exception if the OrtValue does not cont...
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values....
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
Value GetValue(int index, OrtAllocator *allocator) const
size_t GetCount() const
If a non-tensor, returns 2 for a map and N for a sequence, where N is the number of elements.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor....
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed,...
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
Definition onnxruntime_cxx_api.h:3064
Status GetInitializer(ConstValue &value) const
A wrapper around OrtApi::ValueInfo_GetInitializerValue.
std::string GetName() const
A wrapper around OrtApi::GetValueInfoName.
bool IsFromOuterScope() const
A wrapper around OrtApi::ValueInfo_IsFromOuterScope.
Status GetExternalInitializerInfo(ExternalInitializerInfo &info) const
A wrapper around OrtApi::ValueInfo_GetExternalInitializerInfo.
bool IsConstantInitializer() const
A wrapper around OrtApi::ValueInfo_IsConstantInitializer.
std::vector< ValueInfoConsumerProducerInfo > GetConsumers() const
A wrapper around OrtApi::ValueInfo_GetValueConsumers.
bool IsGraphOutput() const
A wrapper around OrtApi::ValueInfo_IsGraphOutput.
bool IsRequiredGraphInput() const
A wrapper around OrtApi::ValueInfo_IsRequiredGraphInput.
ConstTypeInfo TypeInfo() const
A wrapper around OrtApi::GetValueInfoTypeInfo.
ValueInfoConsumerProducerInfo GetProducerNode() const
bool IsOptionalGraphInput() const
A wrapper around OrtApi::ValueInfo_IsOptionalGraphInput.
Definition onnxruntime_cxx_api.h:1113
const char * EpName() const
const char * EpVendor() const
ConstKeyValuePairs EpOptions() const
ConstHardwareDevice Device() const
ConstMemoryInfo GetMemoryInfo(OrtDeviceMemoryType memory_type) const
Wraps EpDevice_MemoryInfo.
SyncStream CreateSyncStream(ConstKeyValuePairs stream_options={}) const
ConstKeyValuePairs EpMetadata() const
Definition onnxruntime_cxx_api.h:3238
void SetInputs(std::vector< ValueInfo > &inputs)
void SetOutputs(std::vector< ValueInfo > &outputs)
void AddNode(Node &node)
void AddInitializer(const std::string &name, Value &initializer, bool data_is_external)
Definition onnxruntime_cxx_api.h:1094
OrtHardwareDeviceType Type() const
const char * Vendor() const
ConstKeyValuePairs Metadata() const
Definition onnxruntime_cxx_api.h:2460
void BindOutput(const char *name, const Value &)
void BindInput(const char *name, const Value &)
void BindOutput(const char *name, const OrtMemoryInfo *)
Definition onnxruntime_cxx_api.h:933
void GetKeyValuePairs(std::vector< const char * > &keys, std::vector< const char * > &values) const
std::unordered_map< std::string, std::string > GetKeyValuePairs() const
const char * GetValue(const char *key) const
Definition onnxruntime_cxx_api.h:1866
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
Definition onnxruntime_cxx_api.h:972
std::string GetAllocatorName() const
Wraps OrtApi::MemoryInfoGetName.
int GetDeviceId() const
Wraps OrtApi::MemoryInfoGetId.
OrtMemType GetMemoryType() const
Wraps OrtApi::MemoryInfoGetMemType.
OrtDeviceMemoryType GetDeviceMemoryType() const
Wraps OrtApi::MemoryInfoGetDeviceMemType.
OrtMemoryInfoDeviceType GetDeviceType() const
Wraps OrtApi::MemoryInfoGetDeviceType.
OrtAllocatorType GetAllocatorType() const
Wraps OrtApi::MemoryInfoGetType.
uint32_t GetVendorId() const
Wraps OrtApi::MemoryInfoGetVendorId.
bool operator==(const MemoryInfoImpl< U > &o) const
Definition onnxruntime_cxx_api.h:3278
void AddGraph(Graph &graph)
Definition onnxruntime_cxx_api.h:1853
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
Definition onnxruntime_cxx_api.h:1942
const char ** str
Definition onnxruntime_cxx_api.h:1947
const int64_t * values_shape
Definition onnxruntime_cxx_api.h:1943
size_t values_shape_len
Definition onnxruntime_cxx_api.h:1944
const void * p_data
Definition onnxruntime_cxx_api.h:1946
Definition onnxruntime_cxx_api.h:1829
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
Definition onnxruntime_cxx_api.h:1656
void SetEpDynamicOptions(const char *const *keys, const char *const *values, size_t kv_len)
Set DynamicOptions for EPs (Execution Providers)
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
void FinalizeModelEditorSession(const Model &model, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container=nullptr)
void Run(const RunOptions &run_options, const IoBinding &)
Wraps OrtApi::RunWithBinding.
void RunAsync(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count, RunAsyncCallbackFn callback, void *user_data)
Run the model asynchronously in a thread owned by intra op thread pool.
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
void Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count)
Run the model, returning results in user-provided outputs. Same as Run(const RunOptions&,...
Definition onnxruntime_cxx_api.h:1953
const int64_t * shape
Definition onnxruntime_cxx_api.h:1954
size_t shape_len
Definition onnxruntime_cxx_api.h:1955
Definition onnxruntime_cxx_api.h:3393
Status StoreWeightData(void **buffer_data_ptrs, size_t *buffer_sizes, size_t num_buffers)
Definition onnxruntime_cxx_api.h:1075
void * GetHandle()
Wraps SyncStream_GetHandle.
Definition onnxruntime_cxx_api.h:1779
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
std::vector< const char * > GetSymbolicDimensions() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
bool HasShape() const
Wraps OrtApi::TensorTypeAndShape_HasShape.
Definition onnxruntime_cxx_api.h:1891
ONNXType GetONNXType() const
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
ConstOptionalTypeInfo GetOptionalTypeInfo() const
Wraps OrtApi::CastTypeInfoToOptionalTypeInfo.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
Definition onnxruntime_cxx_api.h:663
T Type
Definition onnxruntime_cxx_api.h:664
Definition onnxruntime_cxx_api.h:2128
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor.
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. No type checking is performe...
R & At(const std::vector< int64_t > &location)
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
Memory allocation interface.
Definition onnxruntime_c_api.h:346
void(* Free)(struct OrtAllocator *this_, void *p)
Free a block of memory previously allocated with OrtAllocator::Alloc.
Definition onnxruntime_c_api.h:353
const OrtApi *(* GetApi)(uint32_t version)
Get a pointer to the requested version of the OrtApi.
Definition onnxruntime_c_api.h:898
Definition onnxruntime_c_api.h:968
const OrtEpApi *(* GetEpApi)(void)
Get the OrtEpApi instance for implementing an execution provider.
Definition onnxruntime_c_api.h:5441
const OrtCompileApi *(* GetCompileApi)(void)
Get the Compile API instance.
Definition onnxruntime_c_api.h:5173
void(* ReleaseTensorRTProviderOptions)(OrtTensorRTProviderOptionsV2 *input)
Release an OrtTensorRTProviderOptionsV2.
Definition onnxruntime_c_api.h:3224
const OrtModelEditorApi *(* GetModelEditorApi)(void)
Get the Model Editor API instance.
Definition onnxruntime_c_api.h:5115
void(* ReleaseCUDAProviderOptions)(OrtCUDAProviderOptionsV2 *input)
Release an OrtCUDAProviderOptionsV2.
Definition onnxruntime_c_api.h:3727
CUDA Provider Options.
Definition onnxruntime_c_api.h:601
The OrtCompileApi struct provides functions to compile ONNX models.
Definition onnxruntime_c_api.h:7224
Definition onnxruntime_c_api.h:6697
int(* GetVariadicInputHomogeneity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6743
OrtCustomOpInputOutputCharacteristic(* GetOutputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:6727
size_t(* GetInputTypeCount)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6715
int(* GetVariadicOutputMinArity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6747
size_t(* GetAliasMap)(int **input_index, int **output_index)
Definition onnxruntime_c_api.h:6780
int(* GetStartVersion)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6765
void(* ReleaseMayInplace)(int *input_index, int *output_index)
Definition onnxruntime_c_api.h:6777
const char *(* GetName)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6708
size_t(* GetOutputTypeCount)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6717
void(* KernelDestroy)(void *op_kernel)
Definition onnxruntime_c_api.h:6723
int(* GetVariadicOutputHomogeneity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6752
OrtMemType(* GetInputMemoryType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:6734
void *(* CreateKernel)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info)
Definition onnxruntime_c_api.h:6704
uint32_t version
Definition onnxruntime_c_api.h:6698
ONNXTensorElementDataType(* GetInputType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:6714
void(* ReleaseAliasMap)(int *input_index, int *output_index)
Definition onnxruntime_c_api.h:6781
OrtCustomOpInputOutputCharacteristic(* GetInputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:6726
const char *(* GetExecutionProviderType)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6711
ONNXTensorElementDataType(* GetOutputType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:6716
int(* GetVariadicInputMinArity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6738
OrtStatusPtr(* InferOutputShapeFn)(const struct OrtCustomOp *op, OrtShapeInferContext *)
Definition onnxruntime_c_api.h:6762
int(* GetEndVersion)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:6766
OrtStatusPtr(* CreateKernelV2)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info, void **kernel)
Definition onnxruntime_c_api.h:6755
size_t(* GetMayInplace)(int **input_index, int **output_index)
Definition onnxruntime_c_api.h:6773
OrtStatusPtr(* KernelComputeV2)(void *op_kernel, OrtKernelContext *context)
Definition onnxruntime_c_api.h:6760
void(* KernelCompute)(void *op_kernel, OrtKernelContext *context)
Definition onnxruntime_c_api.h:6722
The OrtEpApi struct provides functions that are relevant to the implementation of an execution provid...
Definition onnxruntime_ep_c_api.h:429
The OrtEpFactory provides functions to create and manage execution providers.
Definition onnxruntime_ep_c_api.h:1325
The OrtEp struct provides functions to implement for an execution provider.
Definition onnxruntime_ep_c_api.h:1007
MIGraphX Provider Options.
Definition onnxruntime_c_api.h:805
The OrtModelEditorApi struct provides functions to create or edit an ONNX model.
Definition onnxruntime_c_api.h:6795
OpenVINO Provider Options.
Definition onnxruntime_c_api.h:844
ROCM Provider Options.
Definition onnxruntime_c_api.h:688
TensorRT Provider Options.
Definition onnxruntime_c_api.h:777