ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_cxx_api.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Summary: The Ort C++ API is a header-only wrapper around the Ort C API.
//
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
// and automatically releasing resources in the destructors. The primary purpose of the C++ API is exception safety so
// all the resources follow RAII and do not leak memory.
//
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
// until you assign an instance that actually holds an underlying object.
//
// For Ort objects only move assignment between objects is allowed, there are no copy constructors.
// Some objects have explicit 'Clone' methods for this purpose.
//
// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
// by value or by reference. ConstXXXX types are restricted to const only interfaces.
//
// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
//
// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
// have to fall back to C types and the API with the usual pitfalls. In general, do not use the C API from your C++ code.
24
25#pragma once
26#include "onnxruntime_c_api.h"
27#include "onnxruntime_float16.h"
28
29#include <array>
30#include <cstddef>
31#include <cstdio>
32#include <memory>
33#include <stdexcept>
34#include <string>
35#include <type_traits>
36#include <unordered_map>
37#include <utility>
38#include <variant>
39#include <vector>
40
41#ifdef ORT_NO_EXCEPTIONS
42#include <iostream>
43#endif
44
48namespace Ort {
49
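// Exception type used by ORT_CXX_API_THROW. It carries a human-readable
// message together with the OrtErrorCode that accompanied the failure.
struct Exception : std::exception {
  // Copies `string` into the exception and records `code`.
  Exception(const std::string& string, OrtErrorCode code) : message_{string}, code_{code} {}
  // Moves `string` into the exception (no copy of the text) and records `code`.
  Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}

  // Returns the error code supplied at construction.
  OrtErrorCode GetOrtErrorCode() const { return code_; }
  // Returns the stored message; the pointer stays valid for the lifetime of this object.
  const char* what() const noexcept override { return message_.c_str(); }

 private:
  std::string message_;
  OrtErrorCode code_;
};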
65
#ifdef ORT_NO_EXCEPTIONS
// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
#ifndef ORT_CXX_API_THROW
// Exceptions disabled: print the message to std::cerr and abort the process.
#define ORT_CXX_API_THROW(string, code)       \
  do {                                        \
    std::cerr << Ort::Exception(string, code) \
                     .what()                  \
              << std::endl;                   \
    abort();                                  \
  } while (false)
#endif
#else
// Exceptions enabled: throw an Ort::Exception built from the message and code.
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
82
83#ifdef ORT_API_MANUAL_INIT
84// If the macro ORT_API_MANUAL_INIT is defined, no static initialization
85// will be performed. Instead, users must call InitApi() before using the
86// ORT C++ APIs..
87//
88// InitApi() sets the global API object using the default initialization
89// logic. Users call this to initialize the ORT C++ APIs at a time that
90// makes sense in their program.
91inline void InitApi() noexcept;
92
93// InitApi(const OrtApi*) is used by custom operator libraries that are not
94// linked to onnxruntime. It sets the global API object, which is required
95// by the ORT C++ APIs.
96//
97// Example mycustomop.cc:
98//
99// #define ORT_API_MANUAL_INIT
100// #include <onnxruntime_cxx_api.h>
101// #undef ORT_API_MANUAL_INIT
102//
103// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
104// Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
105// // ...
106// }
107//
108inline void InitApi(const OrtApi* api) noexcept;
109#endif
110
111namespace detail {
112// This is used internally by the C++ API. This class holds the global
113// variable that points to the OrtApi.
114struct Global {
115 static const OrtApi* Api(const OrtApi* newValue = nullptr) noexcept {
116 // This block-level static will be initialized once when this function is
117 // first executed, delaying the call to DefaultInit() until it is first needed.
118 //
119 // When ORT_API_MANUAL_INIT is not defined, DefaultInit() calls
120 // OrtGetApiBase()->GetApi(), which may result in a shared library being
121 // loaded.
122 //
123 // Using a block-level static instead of a class-level static helps
124 // avoid issues with static initialization order and dynamic libraries
125 // loading other dynamic libraries.
126 //
127 // This makes it safe to include the C++ API headers in a shared library
128 // that is delay loaded or delay loads its dependencies.
129 //
130 // This DOES NOT make it safe to _use_ arbitrary ORT C++ APIs when
131 // initializing static members, however.
132 static const OrtApi* api = DefaultInit();
133
134 if (newValue) {
135 api = newValue;
136 }
137
138 return api;
139 }
140
141 private:
142 // Has different definitions based on ORT_API_MANUAL_INIT
143 static const OrtApi* DefaultInit() noexcept;
144
145#ifdef ORT_API_MANUAL_INIT
146 // Public APIs to set the OrtApi* to use.
147 friend void ::Ort::InitApi() noexcept;
148 friend void ::Ort::InitApi(const OrtApi*) noexcept;
149#endif
150};
151} // namespace detail
152
153#ifdef ORT_API_MANUAL_INIT
154
155// See comments on declaration above for usage.
156inline void InitApi(const OrtApi* api) noexcept { detail::Global::Api(api); }
157inline void InitApi() noexcept { InitApi(OrtGetApiBase()->GetApi(ORT_API_VERSION)); }
158
159#ifdef _MSC_VER
160// If you get a linker error about a mismatch here, you are trying to
161// link two compilation units that have different definitions for
162// ORT_API_MANUAL_INIT together. All compilation units must agree on the
163// definition of ORT_API_MANUAL_INIT.
164#pragma detect_mismatch("ORT_API_MANUAL_INIT", "enabled")
165#endif
166
167inline const OrtApi* detail::Global::DefaultInit() noexcept {
168 // When ORT_API_MANUAL_INIT is defined, there's no default init that can
169 // be done.
170 return nullptr;
171}
172
173#else // ORT_API_MANUAL_INIT
174
175#ifdef _MSC_VER
176// If you get a linker error about a mismatch here, you are trying to link
177// two compilation units that have different definitions for
178// ORT_API_MANUAL_INIT together. All compilation units must agree on the
179// definition of ORT_API_MANUAL_INIT.
180#pragma detect_mismatch("ORT_API_MANUAL_INIT", "disabled")
181#endif
182
183inline const OrtApi* detail::Global::DefaultInit() noexcept {
185}
186#endif // ORT_API_MANUAL_INIT
187
189inline const OrtApi& GetApi() noexcept { return *detail::Global::Api(); }
190
// NOTE(review): presumably returns the ONNX Runtime version as text — the
// definition is not in this part of the file; confirm there.
std::string GetVersionString();

// NOTE(review): presumably returns a description of the build configuration —
// confirm against the definition.
std::string GetBuildInfoString();

// NOTE(review): presumably lists the names of the execution providers that are
// available in this build — confirm against the definition.
std::vector<std::string> GetAvailableProviders();
210
216 auto* api = GetApi().GetModelEditorApi();
217 if (api == nullptr) {
218 // minimal build
219 ORT_CXX_API_THROW("Model Editor API is not available in this build", ORT_FAIL);
220 }
221
222 return *api;
223}
224
230 auto* api = GetApi().GetCompileApi();
231 if (api == nullptr) {
232 // minimal build
233 ORT_CXX_API_THROW("Compile API is not available in this build", ORT_FAIL);
234 }
235
236 return *api;
237}
238
244 auto* api = GetApi().GetInteropApi();
245 if (api == nullptr) {
246 // minimal build
247 ORT_CXX_API_THROW("Interop API is not available in this build", ORT_FAIL);
248 }
249
250 return *api;
251}
252
257inline const OrtEpApi& GetEpApi() {
258 auto* api = GetApi().GetEpApi();
259 if (api == nullptr) {
260 // minimal build
261 ORT_CXX_API_THROW("EP API is not available in this build", ORT_FAIL);
262 }
263
264 return *api;
265}
266
285struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
286 private:
292 constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
293
294 public:
295 using Base = onnxruntime_float16::Float16Impl<Float16_t>;
296
300 Float16_t() = default;
301
307 constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
308
313 explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
314
319 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
320
325 using Base::IsNegative;
326
331 using Base::IsNaN;
332
337 using Base::IsFinite;
338
343 using Base::IsPositiveInfinity;
344
349 using Base::IsNegativeInfinity;
350
355 using Base::IsInfinity;
356
361 using Base::IsNaNOrZero;
362
367 using Base::IsNormal;
368
373 using Base::IsSubnormal;
374
379 using Base::Abs;
380
385 using Base::Negate;
386
395 using Base::AreZero;
396
400 explicit operator float() const noexcept { return ToFloat(); }
401
402 using Base::operator==;
403 using Base::operator!=;
404 using Base::operator<;
405};
406
407static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
408
427struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
428 private:
436 constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
437
438 public:
439 using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
440
441 BFloat16_t() = default;
442
448 static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
449
454 explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
455
460 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
461
466 using Base::IsNegative;
467
472 using Base::IsNaN;
473
478 using Base::IsFinite;
479
484 using Base::IsPositiveInfinity;
485
490 using Base::IsNegativeInfinity;
491
496 using Base::IsInfinity;
497
502 using Base::IsNaNOrZero;
503
508 using Base::IsNormal;
509
514 using Base::IsSubnormal;
515
520 using Base::Abs;
521
526 using Base::Negate;
527
536 using Base::AreZero;
537
541 explicit operator float() const noexcept { return ToFloat(); }
542
543 // We do not have an inherited impl for the below operators
544 // as the internal class implements them a little differently
545 bool operator==(const BFloat16_t& rhs) const noexcept;
546 bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
547 bool operator<(const BFloat16_t& rhs) const noexcept;
548};
549
550static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
551
// One-byte storage type for Float8E4M3FN values. Only the raw bit pattern is
// held; no numeric conversion is performed by this type.
struct Float8E4M3FN_t {
  uint8_t value;
  // Zero bit pattern by default.
  constexpr Float8E4M3FN_t() noexcept : value(0) {}
  // Wraps the raw bit pattern `v` (implicit on purpose: bits in, bits out).
  constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}
  // Yields the raw bit pattern.
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
568
// One-byte storage type for Float8E4M3FNUZ values. Only the raw bit pattern is
// held; no numeric conversion is performed by this type.
struct Float8E4M3FNUZ_t {
  uint8_t value;
  // Zero bit pattern by default.
  constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}
  // Wraps the raw bit pattern `v` (implicit on purpose: bits in, bits out).
  constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}
  // Yields the raw bit pattern.
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
585
// One-byte storage type for Float8E5M2 values. Only the raw bit pattern is
// held; no numeric conversion is performed by this type.
struct Float8E5M2_t {
  uint8_t value;
  // Zero bit pattern by default.
  constexpr Float8E5M2_t() noexcept : value(0) {}
  // Wraps the raw bit pattern `v` (implicit on purpose: bits in, bits out).
  constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {}
  // Yields the raw bit pattern.
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
602
// One-byte storage type for Float8E5M2FNUZ values. Only the raw bit pattern is
// held; no numeric conversion is performed by this type.
struct Float8E5M2FNUZ_t {
  uint8_t value;
  // Zero bit pattern by default.
  constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}
  // Wraps the raw bit pattern `v` (implicit on purpose: bits in, bits out).
  constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}
  // Yields the raw bit pattern.
  constexpr operator uint8_t() const noexcept { return value; }
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
619
620namespace detail {
621// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
622// This can't be done in the C API since C doesn't have function overloading.
623#define ORT_DEFINE_RELEASE(NAME) \
624 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
625
626#define ORT_DEFINE_RELEASE_FROM_API_STRUCT(NAME, API_GETTER) \
627 inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); }
628
629ORT_DEFINE_RELEASE(Allocator);
630ORT_DEFINE_RELEASE(ArenaCfg);
631ORT_DEFINE_RELEASE(CustomOpDomain);
632ORT_DEFINE_RELEASE(Env);
633ORT_DEFINE_RELEASE(ExternalInitializerInfo);
634ORT_DEFINE_RELEASE(Graph);
635ORT_DEFINE_RELEASE(IoBinding);
636ORT_DEFINE_RELEASE(KernelInfo);
637ORT_DEFINE_RELEASE(KeyValuePairs);
638ORT_DEFINE_RELEASE(LoraAdapter);
639ORT_DEFINE_RELEASE(MemoryInfo);
640ORT_DEFINE_RELEASE(MapTypeInfo);
641ORT_DEFINE_RELEASE(Model);
642ORT_DEFINE_RELEASE(ModelMetadata);
643ORT_DEFINE_RELEASE(Node);
644ORT_DEFINE_RELEASE(Op);
645ORT_DEFINE_RELEASE(OpAttr);
646ORT_DEFINE_RELEASE(PrepackedWeightsContainer);
647ORT_DEFINE_RELEASE(RunOptions);
648ORT_DEFINE_RELEASE(Session);
649ORT_DEFINE_RELEASE(SessionOptions);
650ORT_DEFINE_RELEASE(SequenceTypeInfo);
651ORT_DEFINE_RELEASE(Status);
652ORT_DEFINE_RELEASE(SyncStream);
653ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
654ORT_DEFINE_RELEASE(ThreadingOptions);
655ORT_DEFINE_RELEASE(TypeInfo);
656ORT_DEFINE_RELEASE(Value);
657ORT_DEFINE_RELEASE(ValueInfo);
658
659ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
660ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
661ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi);
662ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);
663ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi);
664ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi);
665ORT_DEFINE_RELEASE_FROM_API_STRUCT(ProfilingEvent, GetEpApi);
666
667// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
668// but the struct has V2 in its name to indicate that it is the second version of the options.
671
672#undef ORT_DEFINE_RELEASE
673#undef ORT_DEFINE_RELEASE_FROM_API_STRUCT
674
// Tag type used as the template argument of Base<> to select the non-owning
// specialization; `Type` names the wrapped C type.
template <typename T>
struct Unowned {
  using Type = T;
};
682
// Owning holder for a pointer to a C API object of type T. The destructor and
// move assignment hand the pointer to the matching OrtRelease() overload.
// Move-only: copying is deleted.
template <typename T>
struct Base {
  using contained_type = T;

  // Empty holder (null pointer).
  constexpr Base() = default;
  // Takes ownership of `p`.
  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
  // Releases the held object through OrtRelease().
  ~Base() {
    OrtRelease(p_);
  }

  Base(const Base&) = delete;
  Base& operator=(const Base&) = delete;

  // Steals the pointer from `v`, leaving `v` empty.
  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
  // Releases the current object, then steals the pointer from `v`.
  Base& operator=(Base&& v) noexcept {
    // Guard against self-move: without it the object would be released and
    // the now-dangling pointer immediately re-adopted.
    if (this != &v) {
      OrtRelease(p_);
      p_ = v.release();
    }
    return *this;
  }

  // Access to the raw pointer / object; ownership is not transferred.
  constexpr operator contained_type*() const noexcept { return p_; }
  constexpr contained_type& operator*() const noexcept { return *p_; }

  // Gives up ownership: returns the raw pointer and leaves this holder empty.
  // The caller becomes responsible for releasing the object.
  contained_type* release() {
    T* p = p_;
    p_ = nullptr;
    return p;
  }

 protected:
  contained_type* p_{};
};
736
737// Undefined. For const types use Base<Unowned<const T>>
738template <typename T>
739struct Base<const T>;
740
748template <typename T>
749struct Base<Unowned<T>> {
751
752 constexpr Base() = default;
753 constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
754
755 ~Base() = default;
756
757 Base(const Base&) = default;
758 Base& operator=(const Base&) = default;
759
760 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
761 Base& operator=(Base&& v) noexcept {
762 p_ = nullptr;
763 std::swap(p_, v.p_);
764 return *this;
765 }
766
767 constexpr operator contained_type*() const noexcept { return p_; }
768 constexpr contained_type& operator*() const noexcept { return *p_; }
769
770 protected:
772};
773
774// Light functor to release memory with OrtAllocator
777 explicit AllocatedFree(OrtAllocator* allocator)
778 : allocator_(allocator) {}
779 void operator()(void* ptr) const {
780 if (ptr) allocator_->Free(allocator_, ptr);
781 }
782};
783
784} // namespace detail
785
786struct AllocatorWithDefaultOptions;
787struct Env;
788struct EpDevice;
789struct ExternalInitializerInfo;
790struct Graph;
791struct Model;
792struct Node;
793struct ModelMetadata;
794struct TypeInfo;
795struct PrepackedWeightsContainer;
796struct Session;
797struct SessionOptions;
798struct SyncStream;
799struct TensorRTProviderOptions;
800struct Value;
801struct ValueInfo;
802
807using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
808
813struct Status : detail::Base<OrtStatus> {
814 Status() = default; // Same as with std::nullptr_t. But can be used in re-sizable containers and represent success.
815 explicit Status(std::nullptr_t) noexcept {}
816 explicit Status(OrtStatus* status) noexcept;
817 explicit Status(const Exception&);
818 explicit Status(const std::exception&);
819 Status(const char* message, OrtErrorCode code);
820 std::string GetErrorMessage() const;
822 bool IsOK() const noexcept;
823};
824
854
859struct TensorRTProviderOptions : detail::Base<OrtTensorRTProviderOptionsV2> {
860 TensorRTProviderOptions(std::nullptr_t) {}
864 void Update(const std::unordered_map<std::string, std::string>& options);
866 void UpdateWithValue(const char* key, void* value);
867
869 void* GetOptionByName(const char* name) const;
872};
873
878struct CUDAProviderOptions : detail::Base<OrtCUDAProviderOptionsV2> {
879 CUDAProviderOptions(std::nullptr_t) {}
883 void Update(const std::unordered_map<std::string, std::string>& options);
887 void UpdateWithValue(const char* key, void* value);
889 void* GetOptionByName(const char* name) const;
890};
891
906
907namespace detail {
908template <typename T>
910 using B = Base<T>;
911 using B::B;
912
913 // Wraps OrtApi::ExternalInitializerInfo_GetFilePath
914 const std::basic_string<ORTCHAR_T> GetFilePath() const;
915 // Wraps OrtApi::ExternalInitializerInfo_GetFileOffset
916 int64_t GetFileOffset() const;
917 // Wraps OrtApi::ExternalInitializerInfo_GetByteSize
918 size_t GetByteSize() const;
919};
920} // namespace detail
921
922// Const object holder that does not own the underlying object
925
931 using Base::Base;
932
933 explicit ExternalInitializerInfo(std::nullptr_t) {}
935 : detail::ConstExternalInitializerInfoImpl<OrtExternalInitializerInfo>{p} {}
936
938
940 ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size);
941
943 static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size,
944 /*out*/ ExternalInitializerInfo& out);
945};
946
947namespace detail {
948template <typename T>
951 using B::B;
952
953 const char* GetValue(const char* key) const;
954
955 // get the pairs in unordered_map. needs to copy to std::string so the hash works as expected
956 std::unordered_map<std::string, std::string> GetKeyValuePairs() const;
957 // get the pairs in two vectors. entries will be 1:1 between keys and values. avoids copying to std::string
958 void GetKeyValuePairs(std::vector<const char*>& keys, std::vector<const char*>& values) const;
959};
960} // namespace detail
961
962// Const object holder that does not own the underlying object
964
966struct KeyValuePairs : detail::KeyValuePairsImpl<OrtKeyValuePairs> {
967 explicit KeyValuePairs(std::nullptr_t) {}
969 explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl<OrtKeyValuePairs>{p} {}
970
972 explicit KeyValuePairs();
973
975 explicit KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs);
976
978 void Add(const char* key, const char* value);
979
981 void Remove(const char* key);
982
983 ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; }
984};
985
986namespace detail {
987template <typename T>
988struct MemoryInfoImpl : Base<T> {
989 using B = Base<T>;
990 using B::B;
991
992 std::string GetAllocatorName() const;
994 int GetDeviceId() const;
998 uint32_t GetVendorId() const;
999
1000 template <typename U>
1001 bool operator==(const MemoryInfoImpl<U>& o) const;
1002};
1003} // namespace detail
1004
1005// Const object holder that does not own the underlying object
1007
1011struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
1013 explicit MemoryInfo(std::nullptr_t) {}
1014 explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {}
1015 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1016 MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id,
1017 OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type);
1018 ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
1019};
1020
1028 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1033 MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1034
1035 void* get() { return p_; }
1036 size_t size() const { return size_; }
1037
1038 private:
1039 OrtAllocator* allocator_;
1040 void* p_;
1041 size_t size_;
1042};
1043
1044namespace detail {
1045template <typename T>
1046struct AllocatorImpl : Base<T> {
1047 using B = Base<T>;
1048 using B::B;
1049
1050 void* Alloc(size_t size);
1051 void* Reserve(size_t size);
1052 MemoryAllocation GetAllocation(size_t size);
1053 void Free(void* p);
1054 ConstMemoryInfo GetInfo() const;
1055
1060 KeyValuePairs GetStats() const;
1061
1066 void Shrink();
1067};
1068} // namespace detail
1069
1073struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1074 explicit AllocatorWithDefaultOptions(std::nullptr_t) {}
1076};
1077
1082struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1083 explicit Allocator(std::nullptr_t) {}
1084 Allocator(const Session& session, const OrtMemoryInfo*);
1085
1087 explicit Allocator(OrtAllocator* p) : AllocatorImpl<OrtAllocator>{p} {}
1088};
1089
1090using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1091
1096namespace detail {
1097template <typename T>
1099 using B = Base<T>;
1100 using B::B;
1101 // For some reason this is not a const method on the stream
1102 void* GetHandle();
1103};
1104} // namespace detail
1105
1106struct SyncStream : detail::SyncStreamImpl<OrtSyncStream> {
1108 explicit SyncStream(std::nullptr_t) {}
1110 explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl<OrtSyncStream>{p} {}
1111};
1112
1114
1115namespace detail {
1116template <typename T>
1119 using B::B;
1120
1122 uint32_t VendorId() const;
1123 uint32_t DeviceId() const;
1124 const char* Vendor() const;
1126};
1127} // namespace detail
1128
1133
1134namespace detail {
1135template <typename T>
1138 using B::B;
1139
1140 const char* EpName() const;
1141 const char* EpVendor() const;
1147};
1148} // namespace detail
1149
1154
1157struct EpDevice : detail::EpDeviceImpl<OrtEpDevice> {
1158 explicit EpDevice(std::nullptr_t) {}
1159 explicit EpDevice(OrtEpDevice* p) : EpDeviceImpl<OrtEpDevice>{p} {}
1160
1162 EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
1163 ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {});
1164};
1165
1173 const std::vector<ConstEpDevice>& ep_devices,
1174 const char* compatibility_info);
1175
1191AllocatedStringPtr GetCompatibilityInfoFromModelAllocated(const ORTCHAR_T* model_path, const char* ep_type,
1192 OrtAllocator* allocator);
1193
1206AllocatedStringPtr GetCompatibilityInfoFromModelBytesAllocated(const void* model_data, size_t model_data_length,
1207 const char* ep_type, OrtAllocator* allocator);
1208
1209namespace detail {
1210template <typename T>
1213 using B::B;
1214
1215 std::string GetName() const;
1216 std::string GetDomain() const;
1217 std::string GetOperatorType() const;
1218};
1219} // namespace detail
1220
1225
1226namespace detail {
1227template <typename T>
1230 using B::B;
1231
1232 std::string GetEpName() const;
1233 std::vector<ConstEpAssignedNode> GetNodes() const;
1234};
1235} // namespace detail
1236
1241
1242namespace detail {
1243template <typename T>
1246 using B::B;
1247
1249 OrtProfilingEventCategory GetCategory() const;
1250
1253 const char* GetName() const;
1254
1256 int64_t GetTimestampUs() const;
1257
1259 int64_t GetDurationUs() const;
1260
1264 const char* GetArgValue(const char* key) const;
1265};
1266} // namespace detail
1267
1276
1284 explicit ProfilingEvent(std::nullptr_t) {}
1286 : ConstProfilingEventImpl<OrtProfilingEvent>{p} {}
1287
1289 ProfilingEvent(OrtProfilingEventCategory category,
1290 int32_t process_id,
1291 int32_t thread_id,
1292 const char* event_name,
1293 int64_t timestamp_us,
1294 int64_t duration_us,
1295 const std::unordered_map<std::string, std::string>& args = {});
1296
1298 ProfilingEvent(OrtProfilingEventCategory category,
1299 int32_t process_id,
1300 int32_t thread_id,
1301 const char* event_name,
1302 int64_t timestamp_us,
1303 int64_t duration_us,
1304 const char* const* arg_keys,
1305 const char* const* arg_values,
1306 size_t num_args);
1307
1309};
1310
1311namespace detail {
1312template <typename T>
1315 using B::B;
1316
1319 Ort::Status AddEvents(const OrtProfilingEvent* const* events, size_t num_events);
1320 Ort::Status AddEvents(const std::vector<ProfilingEvent>& events);
1321};
1322} // namespace detail
1323
1330
1336struct Env : detail::Base<OrtEnv> {
1337 explicit Env(std::nullptr_t) {}
1338
1340 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
1341
1343 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
1344
1346 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
1347
1349 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
1350 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
1351
1353 explicit Env(const OrtEnvCreationOptions* options);
1354
1356 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
1357
1360
1362
1363 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
1364
1365 Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info,
1366 const std::unordered_map<std::string, std::string>& options,
1367 const OrtArenaCfg* arena_cfg);
1368
1370
1372
1374 OrtAllocatorType allocator_type,
1375 const OrtKeyValuePairs* allocator_options);
1376
1377 // Result may be nullptr
1379
1381 OrtDeviceMemoryType mem_type);
1382
1383 Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string<ORTCHAR_T>& path);
1384 Env& UnregisterExecutionProviderLibrary(const char* registration_name);
1385
1386 std::vector<ConstEpDevice> GetEpDevices() const;
1387
1388 Status CopyTensors(const std::vector<Value>& src_tensors,
1389 const std::vector<Value>& dst_tensors,
1390 OrtSyncStream* stream) const;
1391
1394 Status CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const;
1395
1401};
1402
1406struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
1408 using Base::Base;
1409
1410 explicit CustomOpDomain(std::nullptr_t) {}
1411
1413 explicit CustomOpDomain(const char* domain);
1414
1415 // This does not take ownership of the op, simply registers it.
1416 void Add(const OrtCustomOp* op);
1417};
1418
1420struct LoraAdapter : detail::Base<OrtLoraAdapter> {
1422 using Base::Base;
1423
1424 explicit LoraAdapter(std::nullptr_t) {}
1431 static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
1432 OrtAllocator* allocator);
1433
1441 static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
1442 OrtAllocator* allocator);
1443};
1444
1448struct RunOptions : detail::Base<OrtRunOptions> {
1449 explicit RunOptions(std::nullptr_t) {}
1451
1454
1457
1458 RunOptions& SetRunTag(const char* run_tag);
1459 const char* GetRunTag() const;
1460
1461 RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
1462 const char* GetConfigEntry(const char* config_key);
1463
1470
1476
1484
1493
1499 RunOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
1500
1506};
1507
1508namespace detail {
// Utility function that returns a SessionOption config entry key for a specific custom operator.
// Ex: custom_op.[custom_op_name].[config]
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
1512} // namespace detail
1513
1524 CustomOpConfigs() = default;
1525 ~CustomOpConfigs() = default;
1530
1539 CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
1540
1549 const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
1550
1551 private:
1552 std::unordered_map<std::string, std::string> flat_configs_;
1553};
1554
1560namespace detail {
1561// we separate const-only methods because passing const ptr to non-const methods
1562// is only discovered when inline methods are compiled which is counter-intuitive
1563template <typename T>
1564struct ConstSessionOptionsImpl : Base<T> {
1565 using B = Base<T>;
1566 using B::B;
1567
1568 SessionOptions Clone() const;
1569
1570 std::string GetConfigEntry(const char* config_key) const;
1571 bool HasConfigEntry(const char* config_key) const;
1572 std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const;
1573};
1574
// Mutable session-options wrapper. It layers the configuration setters declared below on top of
// the const-only interface inherited from ConstSessionOptionsImpl<T>.
// Every setter is declared to return SessionOptionsImpl&. The definitions are not part of this
// chunk; presumably each returns *this so calls can be chained — confirm against the definitions.
1575template <typename T>
1576struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
1577 using B = ConstSessionOptionsImpl<T>;
// Inherit the constructors of the const-only base.
1578 using B::B;
1579
// Thread counts, graph optimization level and the deterministic-compute switch.
1580 SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads);
1581 SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads);
1582 SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
1583 SessionOptionsImpl& SetDeterministicCompute(bool value);
1584
// CPU memory arena on/off pair.
1585 SessionOptionsImpl& EnableCpuMemArena();
1586 SessionOptionsImpl& DisableCpuMemArena();
1587
// Path is an ORTCHAR_T string (the platform-dependent character type used for file paths in this API).
1588 SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
1589
// Profiling on/off pair; enabling takes a file-name prefix as an ORTCHAR_T string.
1590 SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
1591 SessionOptionsImpl& DisableProfiling();
1592
1593 SessionOptionsImpl& EnableOrtCustomOps();
1594
// Memory-pattern on/off pair.
1595 SessionOptionsImpl& EnableMemPattern();
1596 SessionOptionsImpl& DisableMemPattern();
1597
1598 SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode);
1599
1600 SessionOptionsImpl& SetLoadCancellationFlag(bool value);
1601
// Logging identity and severity threshold for the session.
1602 SessionOptionsImpl& SetLogId(const char* logid);
1603 SessionOptionsImpl& SetLogSeverityLevel(int level);
1604
// Takes a raw C-API custom-op domain pointer.
// NOTE(review): ownership of custom_op_domain is not visible from this declaration — confirm.
1605 SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain);
1606
1607 SessionOptionsImpl& DisablePerSessionThreads();
1608
// Free-form string key/value configuration entry.
1609 SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value);
1610
// Initializer registration. The vector-based overloads take parallel vectors
// (presumably index i of each vector describes the same initializer/file — confirm).
1611 SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val);
1612 SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values);
1613 SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
1614 const std::vector<char*>& external_initializer_file_buffer_array,
1615 const std::vector<size_t>& external_initializer_file_lengths);
1616
// Execution-provider registration. Provider-specific overloads take that provider's options
// struct; the map-based overloads default provider_options to an empty map.
1617 SessionOptionsImpl& AppendExecutionProvider_CPU(int use_arena);
1618 SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
1619 SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options);
1620 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
1621 SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
1623 SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
1624 SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
1625 SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options);
1626 SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options);
1628 SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
1630 SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
// Generic form: the provider is selected by name rather than by a dedicated method.
1632 SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
1633 const std::unordered_map<std::string, std::string>& provider_options = {});
1634
// Device-based registration: two overloads that differ only in how the EP options are supplied
// (KeyValuePairs vs. std::unordered_map of strings).
1637 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
1638 const KeyValuePairs& ep_options);
1641 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
1642 const std::unordered_map<std::string, std::string>& ep_options);
1643
// EP selection: either a predefined policy value ...
1645 SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy);
1646
// ... or a caller-supplied delegate with an optional opaque state pointer (defaults to nullptr).
1648 SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr);
1649
// Hooks for custom thread creation/joining plus an opaque options pointer.
// NOTE(review): presumably the options pointer is handed to the create-thread function — confirm.
1650 SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
1651 SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
1652 SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
1653
// custom_op_configs defaults to an empty CustomOpConfigs object.
1657 SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
1658
1659 SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name);
1660
1662 SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
1663
// Free-dimension overrides, keyed by dimension denotation or by dimension name.
1665 SessionOptionsImpl& AddFreeDimensionOverride(const char* dim_denotation, int64_t dim_value);
1666
1668 SessionOptionsImpl& AddFreeDimensionOverrideByName(const char* dim_name, int64_t dim_value);
1669};
1670} // namespace detail
1671
// Non-owning views over an OrtSessionOptions: UnownedSessionOptions exposes the mutable
// interface, ConstSessionOptions only the const interface (see the file header notes on
// ConstXXXX / UnownedXXXX types).
1672using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
1673using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
1674
// Session-options wrapper instantiated directly on OrtSessionOptions (not on an Unowned<> view).
1678struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
// Creates an empty object that holds no underlying OrtSessionOptions.
1679 explicit SessionOptions(std::nullptr_t) {}
// Wraps an existing C-API pointer by forwarding it to the base wrapper.
1681 explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {}
1684};
1685
1690struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
1692 using Base::Base;
1693
1694 explicit ModelCompilationOptions(std::nullptr_t) {}
1695
1696 ModelCompilationOptions(const Env& env, const SessionOptions& session_options);
1697 ModelCompilationOptions(const Env& env, ConstSessionOptions session_options);
1698
1699 ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path);
1701 size_t input_model_data_size);
1702 ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model);
1703 ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path);
1705 size_t initializer_size_threshold);
1706
1709 OrtGetInitializerLocationFunc get_initializer_location_func,
1710 void* state);
1711
1712 ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
1713 size_t* output_model_buffer_size_ptr);
1714
1717
1718 ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
1719 const ORTCHAR_T* model_name);
1721
1723
1725};
1726
// Free function taking the environment and a ModelCompilationOptions object; the outcome is
// reported through the returned Status. The definition is not part of this chunk.
1733Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options);
1734
// Wrapper around the C-API OrtModelMetadata object.
// NOTE(review): several accessor declarations appear to be missing from this listing
// (the numbering jumps from 1743 to 1790); only the two visible accessors are described here.
1738struct ModelMetadata : detail::Base<OrtModelMetadata> {
1740 using Base::Base;
1741
// Creates an empty object that holds no underlying OrtModelMetadata.
1742 explicit ModelMetadata(std::nullptr_t) {}
1743
1751
1759
1767
1775
1783
// Returns the custom-metadata map keys as a vector of AllocatedStringPtr.
// NOTE(review): presumably the strings are allocated with `allocator` — confirm in the definition.
1790 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const;
1791
1802
// Returns the version as a 64-bit signed integer.
1803 int64_t GetVersion() const;
1804};
1805
1806struct IoBinding;
1807
1808namespace detail {
1809
1810// we separate const-only methods because passing const ptr to non-const methods
1811// is only discovered when inline methods are compiled which is counter-intuitive
1812template <typename T>
1814 using B = Base<T>;
1815 using B::B;
1816
1817 size_t GetInputCount() const;
1818 size_t GetOutputCount() const;
1820
1821 std::vector<std::string> GetInputNames() const;
1822 std::vector<std::string> GetOutputNames() const;
1823 std::vector<std::string> GetOverridableInitializerNames() const;
1824
1825 std::vector<ConstMemoryInfo> GetMemoryInfoForInputs() const;
1826 std::vector<ConstMemoryInfo> GetMemoryInfoForOutputs() const;
1827 std::vector<ConstEpDevice> GetEpDeviceForInputs() const;
1828 std::vector<ConstEpDevice> GetEpDeviceForOutputs() const;
1829
1838
1847
1856
1857 uint64_t GetProfilingStartTimeNs() const;
1859
1860 TypeInfo GetInputTypeInfo(size_t index) const;
1861 TypeInfo GetOutputTypeInfo(size_t index) const;
1863
1864 int GetOpset(const std::string& domain) const;
1865
1866 std::vector<ValueInfo> GetInputs() const;
1867 std::vector<ValueInfo> GetOutputs() const;
1868
1873 std::vector<ConstEpAssignedSubgraph> GetEpGraphAssignmentInfo() const;
1874};
1875
1876template <typename T>
1879 using B::B;
1880
1898 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1899 const char* const* output_names, size_t output_count);
1900
1904 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1905 const char* const* output_names, Value* output_values, size_t output_count);
1906
1907 void Run(const RunOptions& run_options, const IoBinding&);
1908
1928 void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1929 const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1930
1938
1950 void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len);
1951
1952 void FinalizeModelEditorSession(const Model& model, const SessionOptions& options,
1953 OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr);
1954};
1955
1956} // namespace detail
1957
1960
// Session wrapper instantiated directly on OrtSession. The constructors create a session from a
// model file path, an in-memory model buffer, or (non-minimal builds only) a Model object.
1964struct Session : detail::SessionImpl<OrtSession> {
// Creates an empty object that holds no underlying OrtSession.
1966 explicit Session(std::nullptr_t) {}
// Wraps an existing C-API pointer by forwarding it to the base wrapper.
1967 explicit Session(OrtSession* p) : SessionImpl{p} {}
1968
// From a model file path (ORTCHAR_T string).
1969 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
1970
// From a model file path, additionally passing a prepacked-weights container.
1972 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1973 OrtPrepackedWeightsContainer* prepacked_weights_container);
1974
// From a model held in memory: model_data points at model_data_length bytes.
1976 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
1977
// From a model held in memory, additionally passing a prepacked-weights container.
1979 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1980 OrtPrepackedWeightsContainer* prepacked_weights_container);
1981
// The members below are compiled out when ORT_MINIMAL_BUILD is defined.
1982#if !defined(ORT_MINIMAL_BUILD)
1984 Session(const Env& env, const Model& model, const SessionOptions& options);
1985
// Static factories returning a Session for model editing, from a path or from an in-memory buffer.
1987 static Session CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
1988
1990 static Session CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length,
1991 const SessionOptions& options);
1992#endif // !defined(ORT_MINIMAL_BUILD)
1993
// View accessors: both wrap the same underlying pointer (p_) held by this object; per the file
// header, such views must not outlive this Session.
1994 ConstSession GetConst() const { return ConstSession{this->p_}; }
1995 UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
1996};
1997
1998namespace detail {
1999template <typename T>
2001 using B = Base<T>;
2002 using B::B;
2003
2005 size_t GetElementCount() const;
2006
2007 size_t GetDimensionsCount() const;
2008
2013 [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const;
2014
2015 void GetSymbolicDimensions(const char** values, size_t values_count) const;
2016 std::vector<const char*> GetSymbolicDimensions() const;
2017
2018 bool HasShape() const;
2019 std::vector<int64_t> GetShape() const;
2020};
2021
2022} // namespace detail
2023
2025
2031 using Base::Base;
2032
2034 explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
2036 explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {}
2037
2038 // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions
2039 // symbolic_dims are optional, but should be 1:1 with dims.
2040 // The value in symbolic_dims will be used for all entries in dims that are -1.
2042 const std::vector<int64_t>& dims,
2043 const std::vector<std::string>* symbolic_dims = nullptr);
2044
2046};
2047
2048namespace detail {
2049template <typename T>
2051 using B = Base<T>;
2052 using B::B;
2054};
2055
2056} // namespace detail
2057
2059
2063struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
2065 using Base::Base;
2066
2067 explicit SequenceTypeInfo(std::nullptr_t) {}
2068 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {}
2070};
2071
2072namespace detail {
2073template <typename T>
2075 using B = Base<T>;
2076 using B::B;
2078};
2079
2080} // namespace detail
2081
2082// This is always owned by the TypeInfo and can only be obtained from it.
2084
2085namespace detail {
2086template <typename T>
2093
2094} // namespace detail
2095
2097
2101struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
2103 using Base::Base;
2104
2105 explicit MapTypeInfo(std::nullptr_t) {}
2106 explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {}
2107 ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
2108};
2109
2110namespace detail {
2111template <typename T>
2123} // namespace detail
2124
2130
2135struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
2137 using Base::Base;
2138
2140 explicit TypeInfo(std::nullptr_t) {}
2141 explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {}
2142
2143#if !defined(ORT_MINIMAL_BUILD)
2149#endif // !defined(ORT_MINIMAL_BUILD)
2150
2151 ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
2152};
2153
2154namespace detail {
2155// This structure is used to feed sparse tensor values
2156// information for use with FillSparseTensor<Format>() API
2157// if the data type for the sparse tensor values is numeric
2158// use data.p_data, otherwise, use data.str pointer to feed
2159// values. data.str is an array of const char* that are zero terminated.
2160// number of strings in the array must match shape size.
2161// For fully sparse tensors use shape {0} and set p_data/str
2162// to nullptr.
2164 const int64_t* values_shape;
2166 union {
2167 const void* p_data;
2168 const char** str;
2169 } data;
2170};
2171
2172// Provides a way to pass shape in a single
2173// argument
2174struct Shape {
2175 const int64_t* shape;
2177};
2178
2179template <typename T>
2181 using B = Base<T>;
2182 using B::B;
2183
2187 template <typename R>
2188 void GetOpaqueData(const char* domain, const char* type_name, R&) const;
2189
2190 bool IsTensor() const;
2191 bool HasValue() const;
2192
2193 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
2194 Value GetValue(int index, OrtAllocator* allocator) const;
2195
2203
2218 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
2219
2226 template <typename R>
2227 const R* GetTensorData() const;
2228
2233 const void* GetTensorRawData() const;
2234
2242
2250
2256
2265 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
2266
2273 std::string GetStringTensorElement(size_t element_index) const;
2274
2281 size_t GetStringTensorElementLength(size_t element_index) const;
2282
2289 size_t GetTensorSizeInBytes() const;
2290
2291#if !defined(DISABLE_SPARSE_TENSORS)
2299
2306
2315
2325 template <typename R>
2326 const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
2327
2332 bool IsSparseTensor() const;
2333
2342 template <typename R>
2343 const R* GetSparseTensorValues() const;
2344
2345#endif
2346
2359};
2360
2361template <typename T>
2364 using B::B;
2365
2371 template <typename R>
2373
2379
2381 // Obtain a reference to an element of data at the location specified
2387 template <typename R>
2388 R& At(const std::vector<int64_t>& location);
2389
2395 void FillStringTensor(const char* const* s, size_t s_len);
2396
2402 void FillStringTensorElement(const char* s, size_t index);
2403
2416 char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
2417
2418#if !defined(DISABLE_SPARSE_TENSORS)
2427 void UseCooIndices(int64_t* indices_data, size_t indices_num);
2428
2439 void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
2440
2449 void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
2450
2460 void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
2461 const int64_t* indices_data, size_t indices_num);
2462
2474 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
2475 const OrtSparseValuesParam& values,
2476 const int64_t* inner_indices_data, size_t inner_indices_num,
2477 const int64_t* outer_indices_data, size_t outer_indices_num);
2478
2489 const OrtSparseValuesParam& values,
2490 const Shape& indices_shape,
2491 const int32_t* indices_data);
2492
2493#endif
2494};
2495
2496} // namespace detail
2497
2500
2504struct Value : detail::ValueImpl<OrtValue> {
2506 using Base::Base;
2509
2510 Value(std::nullptr_t) {}
2511 Value(Value&&) = default;
2512 Value& operator=(Value&&) = default;
2513
2514 ConstValue GetConst() const { return ConstValue{this->p_}; }
2515 UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
2516
2525 template <typename T>
2526 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count,
2527 const int64_t* shape, size_t shape_len);
2528
2538 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count,
2539 const int64_t* shape, size_t shape_len,
2541
2551 static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count,
2552 const int64_t* shape, size_t shape_len,
2554
2566 template <typename T>
2567 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
2568
2580 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len,
2582
2591 static Value CreateMap(const Value& keys, const Value& values);
2592
2600 static Value CreateSequence(const std::vector<Value>& values);
2601
2610 template <typename T>
2611 static Value CreateOpaque(const char* domain, const char* type_name, const T& value);
2612
2613#if !defined(DISABLE_SPARSE_TENSORS)
2624 template <typename T>
2625 static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
2626 const Shape& values_shape);
2627
2644 static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
2645 const Shape& values_shape, ONNXTensorElementDataType type);
2646
2656 template <typename T>
2657 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
2658
2670 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
2671
2672#endif // !defined(DISABLE_SPARSE_TENSORS)
2673};
2674
2675namespace detail {
// Non-template helpers operating on a raw OrtIoBinding pointer. Per the original note they are
// kept outside the IoBinding class templates; presumably the templated GetOutputNames /
// GetOutputValues members delegate to them — confirm in the definitions.
2676namespace binding_utils {
2677// Bring these out of template
// Returns the output names of `binding` as std::string values.
2678std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
// Returns the output values of `binding` as Value wrappers.
2679std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
2680} // namespace binding_utils
2681
2682template <typename T>
2684 using B = Base<T>;
2685 using B::B;
2686
2687 std::vector<std::string> GetOutputNames() const;
2688 std::vector<std::string> GetOutputNames(OrtAllocator*) const;
2689 std::vector<Value> GetOutputValues() const;
2690 std::vector<Value> GetOutputValues(OrtAllocator*) const;
2691};
2692
2693template <typename T>
2696 using B::B;
2697
2698 void BindInput(const char* name, const Value&);
2699 void BindOutput(const char* name, const Value&);
2700 void BindOutput(const char* name, const OrtMemoryInfo*);
2705};
2706
2707} // namespace detail
2708
2711
// IoBinding wrapper instantiated directly on OrtIoBinding.
2715struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
// Creates an empty object that holds no underlying OrtIoBinding.
2716 explicit IoBinding(std::nullptr_t) {}
// Creates a binding for the given session (definition not in this chunk).
2717 explicit IoBinding(Session& session);
// View accessors: both wrap the same underlying pointer (p_) held by this object.
2718 ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
2719 UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
2720};
2721
// Wrapper around the C-API OrtArenaCfg object (arena allocator configuration).
2726struct ArenaCfg : detail::Base<OrtArenaCfg> {
// Creates an empty object that holds no underlying OrtArenaCfg.
2727 explicit ArenaCfg(std::nullptr_t) {}
// Builds a configuration from four individual settings.
// NOTE(review): units and sentinel values of these parameters are not visible here — see the C API docs.
2736 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
2737
// Builds a configuration from a map of setting name -> size_t value.
2742 explicit ArenaCfg(const std::unordered_map<std::string, size_t>& arena_config);
2743};
2744
2745//
2746// Custom OPs (only needed to implement custom OPs)
2747//
2748
2749namespace detail {
2750// Need to define a templated ConstOpAttr with const members
2751template <typename T>
2754 using B::B;
2755
2756 // Wraps OrtApi::OpAttr_GetName
2757 std::string GetName() const;
2758 // Wraps OrtApi::OpAttr_GetType
2760
2761 // Wraps OrtApi::ReadAttr for a single value
2762 // This does not support Tensor Attribute
2763 // Call GetTensorAttributeAsOrtValue() instead.
2764 template <typename R>
2765 Status GetValue(R& out) const;
2766
2767 // Wraps OrtApi::ReadAttr for an array of values
2768 template <typename R>
2769 Status GetValueArray(std::vector<R>& out) const;
2770 // Wraps OrtApi::OpAttr_GetTensorAttributeAsOrtValue
2772};
2773} // namespace detail
2774
2776
// OpAttr wrapper instantiated directly on OrtOpAttr; it inherits the const accessors of
// ConstOpAttrImpl.
2780struct OpAttr : detail::ConstOpAttrImpl<OrtOpAttr> {
2782 using Base::Base;
2783
2784 OpAttr() = default; // Enable storing it in the container for resize()
// Creates an empty object that holds no underlying OrtOpAttr.
2785 explicit OpAttr(std::nullptr_t) {}
// Creates an attribute from a name, a raw data pointer, a length and the attribute type.
// NOTE(review): whether `len` counts bytes or elements is not visible here — confirm in the C API docs.
2786 OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
2787
// Returns a ConstOpAttr view wrapping the same underlying pointer (p_).
2788 ConstOpAttr GetConst() const { return ConstOpAttr{this->p_}; }
2789};
2790
// Logging convenience macros.
//
// Each macro emits a message through `logger` only when `message_severity` is greater than or
// equal to the value returned by `logger.GetLoggingSeverityLevel()`.
//
// All macro parameters are parenthesized in the expansion so that expression arguments such as
// `*logger_ptr` or `cond ? ORT_LOGGING_LEVEL_INFO : ORT_LOGGING_LEVEL_WARNING` expand with the
// intended precedence. Without the parentheses `*logger_ptr` would expand to
// `*logger_ptr.GetLoggingSeverityLevel()`.
//
// Note: `logger` and `message_severity` are evaluated more than once; avoid arguments with side effects.

// Logs `message`; the Status returned by LogMessage is passed to Ort::ThrowOnError.
#define ORT_CXX_LOG(logger, message_severity, message) \
  do { \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
      Ort::ThrowOnError((logger).LogMessage((message_severity), ORT_FILE, __LINE__, \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    } \
  } while (false)

// Logs `message`; the Status returned by LogMessage is deliberately discarded.
#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \
  do { \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
      static_cast<void>((logger).LogMessage((message_severity), ORT_FILE, __LINE__, \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    } \
  } while (false)

// Logs a printf-style formatted message. The variadic arguments are the format string followed by
// its arguments; the Status returned by LogFormattedMessage is passed to Ort::ThrowOnError.
#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \
  do { \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
      Ort::ThrowOnError((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__, \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    } \
  } while (false)

// Logs a printf-style formatted message; the Status returned by LogFormattedMessage is
// deliberately discarded.
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \
  do { \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
      static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__, \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    } \
  } while (false)
2860
// Lightweight, copyable handle around a `const OrtLogger*` plus a cached severity level.
// It holds only a pointer and an enum value, and its destructor is defaulted, so nothing is
// released when a Logger is destroyed.
2871struct Logger {
// Leaves logger_ null and cached_severity_level_ value-initialized (see the member initializers).
2875 Logger() = default;
2876
// Same state as the default constructor.
2880 explicit Logger(std::nullptr_t) {}
2881
// Wraps an existing C-API logger (definition not in this chunk).
// NOTE(review): presumably this also fills cached_severity_level_ — confirm in the definition.
2888 explicit Logger(const OrtLogger* logger);
2889
2890 ~Logger() = default;
2891
// Copy and move are all defaulted: member-wise copy of the pointer and the cached level.
2892 Logger(const Logger&) = default;
2893 Logger& operator=(const Logger&) = default;
2894
2895 Logger(Logger&& v) noexcept = default;
2896 Logger& operator=(Logger&& v) noexcept = default;
2897
2904
// Logs a plain message with its source location. Declared noexcept: failures are reported
// through the returned Status rather than by throwing.
2917 Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2918 const char* func_name, const char* message) const noexcept;
2919
// Logs a message built from `format` and `args`. Declared noexcept: failures are reported
// through the returned Status rather than by throwing.
2934 template <typename... Args>
2935 Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2936 const char* func_name, const char* format, Args&&... args) const noexcept;
2937
2938 private:
// Underlying C-API logger; null for default/nullptr-constructed objects.
2939 const OrtLogger* logger_{};
// Severity level stored alongside the logger pointer.
2940 OrtLoggingLevel cached_severity_level_{};
2941};
2942
2951 size_t GetInputCount() const;
2952 size_t GetOutputCount() const;
2953 // If input is optional and is not present, the method returns an empty ConstValue
2954 // which can be compared to nullptr.
2955 ConstValue GetInput(size_t index) const;
2956 // If output is optional and is not present, the method returns an empty UnownedValue
2957 // which can be compared to nullptr.
2958 UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
2959 UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
2960 void* GetGPUComputeStream() const;
2962 Ort::Allocator GetAllocator(const OrtMemoryInfo& memory_info) const;
2963 OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
2964 void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
2965
2966 private:
2967 OrtKernelContext* ctx_;
2968};
2969
2970struct KernelInfo;
2971
2972namespace detail {
// Non-template attribute readers used by KernelInfoImpl::GetAttribute / GetAttributes.
// Each overload reads the attribute called `name` from the kernel info `p` into the output
// reference. Overloads exist only for float, int64_t and std::string, and for vectors of those.
2973namespace attr_utils {
// Single-value attributes.
2974void GetAttr(const OrtKernelInfo* p, const char* name, float&);
2975void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
2976void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
// Array-valued attributes.
2977void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
2978void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2979void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<std::string>&);
2980} // namespace attr_utils
2981
2982template <typename T>
2983struct KernelInfoImpl : Base<T> {
2984 using B = Base<T>;
2985 using B::B;
2986
2987 KernelInfo Copy() const;
2988
2989 template <typename R> // R is only implemented for float, int64_t, and string
2990 R GetAttribute(const char* name) const {
2991 R val;
2992 attr_utils::GetAttr(this->p_, name, val);
2993 return val;
2994 }
2995
2996 template <typename R> // R is only implemented for float, int64_t, and string
2997 std::vector<R> GetAttributes(const char* name) const {
2998 std::vector<R> result;
2999 attr_utils::GetAttrs(this->p_, name, result);
3000 return result;
3001 }
3002
3003 Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
3004
3005 size_t GetInputCount() const;
3006 size_t GetOutputCount() const;
3007
3008 std::string GetInputName(size_t index) const;
3009 std::string GetOutputName(size_t index) const;
3010
3011 TypeInfo GetInputTypeInfo(size_t index) const;
3012 TypeInfo GetOutputTypeInfo(size_t index) const;
3013
3014 ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
3015
3016 std::string GetNodeName() const;
3017 Logger GetLogger() const;
3018
3019 KeyValuePairs GetConfigEntries() const;
3020
3021 std::string GetOperatorDomain() const;
3022 std::string GetOperatorType() const;
3023 int GetOperatorSinceVersion() const;
3024 const OrtEp* GetEp() const;
3025};
3026
3027} // namespace detail
3028
// Non-owning, const-only view over an OrtKernelInfo.
3029using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
3030
// KernelInfo wrapper instantiated directly on OrtKernelInfo.
3037struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
3038 using Base = detail::KernelInfoImpl<OrtKernelInfo>;
3039 using Base::Base;
// Creates an empty object that holds no underlying OrtKernelInfo.
3040 explicit KernelInfo(std::nullptr_t) {}
// Wraps an existing C-API pointer (definition not in this chunk).
3041 explicit KernelInfo(OrtKernelInfo* info);
// Returns a ConstKernelInfo view wrapping the same underlying pointer (p_).
3042 ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
3043};
3044
// Wrapper around the C-API OrtOp object: an operator instance that can be created and invoked
// from inside a custom-op kernel.
3048struct Op : detail::Base<OrtOp> {
3050 using Base::Base;
3051
// Creates an empty object that holds no underlying OrtOp.
3052 explicit Op(std::nullptr_t) {}
3053
// Wraps an existing C-API pointer (definition not in this chunk).
3054 explicit Op(OrtOp*);
3055
// Factory. type_constraint_names and type_constraint_values are parallel arrays of
// type_constraint_count entries; attr_values is an array of attr_count attributes.
3056 static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
3057 int version, const char** type_constraint_names,
3058 const ONNXTensorElementDataType* type_constraint_values,
3059 size_t type_constraint_count,
3060 const OpAttr* attr_values,
3061 size_t attr_count,
3062 size_t input_count, size_t output_count);
3063
// Invokes the operator with inputs and outputs given as arrays of C++ Value wrappers.
3064 void Invoke(const OrtKernelContext* context,
3065 const Value* input_values,
3066 size_t input_count,
3067 Value* output_values,
3068 size_t output_count);
3069
// Same operation, taking arrays of raw C-API OrtValue pointers.
3070 // For easier refactoring
3071 void Invoke(const OrtKernelContext* context,
3072 const OrtValue* const* input_values,
3073 size_t input_count,
3074 OrtValue* const* output_values,
3075 size_t output_count);
3076};
3077
3083 SymbolicInteger(int64_t i) : i_(i), is_int_(true) {};
3084 SymbolicInteger(const char* s) : s_(s), is_int_(false) {};
3087
3090
3091 bool operator==(const SymbolicInteger& dim) const {
3092 if (is_int_ == dim.is_int_) {
3093 if (is_int_) {
3094 return i_ == dim.i_;
3095 } else {
3096 return std::string{s_} == std::string{dim.s_};
3097 }
3098 }
3099 return false;
3100 }
3101
3102 bool IsInt() const { return is_int_; }
3103 int64_t AsInt() const { return i_; }
3104 const char* AsSym() const { return s_; }
3105
3106 static constexpr int INVALID_INT_DIM = -2;
3107
3108 private:
3109 union {
3110 int64_t i_;
3111 const char* s_;
3112 };
3113 bool is_int_;
3114 };
3115
3116 using Shape = std::vector<SymbolicInteger>;
3117
3119
3120 const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
3121
3122 size_t GetInputCount() const { return input_shapes_.size(); }
3123
3125
3126 int64_t GetAttrInt(const char* attr_name);
3127
3128 using Ints = std::vector<int64_t>;
3129 Ints GetAttrInts(const char* attr_name);
3130
3131 float GetAttrFloat(const char* attr_name);
3132
3133 using Floats = std::vector<float>;
3134 Floats GetAttrFloats(const char* attr_name);
3135
3136 std::string GetAttrString(const char* attr_name);
3137
3138 using Strings = std::vector<std::string>;
3139 Strings GetAttrStrings(const char* attr_name);
3140
3141 private:
3142 ConstOpAttr GetAttrHdl(const char* attr_name) const;
3143 const OrtApi* ort_api_;
3145 std::vector<Shape> input_shapes_;
3146};
3147
3149
// Largest opset version a custom op can declare as its end version: 2^31 - 1.
// The whole expansion is parenthesized so the macro behaves as a single value inside larger
// expressions (e.g. `MAX_CUSTOM_OP_END_VER * 2` or `x - MAX_CUSTOM_OP_END_VER`).
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
3151
3152template <typename TOp, typename TKernel, bool WithStatus = false>
3156 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
3157
3158 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
3159
3160 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
3161 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
3162 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
3163
3164 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
3165 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
3166
3167#if defined(_MSC_VER) && !defined(__clang__)
3168#pragma warning(push)
3169#pragma warning(disable : 26409)
3170#endif
3171 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
3172#if defined(_MSC_VER) && !defined(__clang__)
3173#pragma warning(pop)
3174#endif
3175 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
3176 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
3177
3178 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
3179 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
3180 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
3181 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
3182#ifdef __cpp_if_constexpr
3183 if constexpr (WithStatus) {
3184#else
3185 if (WithStatus) {
3186#endif
3187 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
3188 return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
3189 };
3190 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
3191 return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
3192 };
3193 } else {
3196
3197 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
3198 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
3199 static_cast<TKernel*>(op_kernel)->Compute(context);
3200 };
3201 }
3202
3203 SetShapeInferFn<TOp>(0);
3204
3205 OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
3206 return static_cast<const TOp*>(this_)->start_ver_;
3207 };
3208
3209 OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
3210 return static_cast<const TOp*>(this_)->end_ver_;
3211 };
3212
3215 OrtCustomOp::GetAliasMap = nullptr;
3217 }
3218
3219 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
3220 const char* GetExecutionProviderType() const { return nullptr; }
3221
3222 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
3223 // (inputs and outputs are required by default)
3225 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
3226 }
3227
3229 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
3230 }
3231
3232 // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
3233 OrtMemType GetInputMemoryType(size_t /*index*/) const {
3234 return OrtMemTypeDefault;
3235 }
3236
3237 // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
3238 // should expect at least 1 argument.
3240 return 1;
3241 }
3242
3243 // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
3244 // to a variadic input should be of the same type.
3246 return true;
3247 }
3248
3249 // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
3250 // should produce at least 1 output value.
3252 return 1;
3253 }
3254
3255 // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
3256 // produced by a variadic output should be of the same type.
3258 return true;
3259 }
3260
3261 // Declares the list of session config entries used by this Custom Op.
3262 // Provide this function in the op class in order to get configs from CustomOpBase::GetSessionConfigs().
3263 // This default implementation returns an empty vector, i.e. it declares no config entries.
3264 std::vector<std::string> GetSessionConfigKeys() const {
3265 return std::vector<std::string>{};
3266 }
3267
3268 // Ort::CustomOpBase derived class should provide the following static method with the type/shape inferencing
3269 // implementation if needed:
3270 // static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context)
3271 template <typename C>
3272 decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
3274 ShapeInferContext ctx(&GetApi(), ort_ctx);
3275 return C::InferOutputShape(ctx);
3276 };
3277 return {};
3278 }
3279
3280 template <typename C>
3284
3285 protected:
3286 // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
// The entries are written into `out`; `options` is the session options object passed in by the caller.
3287 void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
3288
// Version range of the op. The GetStartVersion / GetEndVersion callbacks installed by the
// constructor return these two members.
3289 int start_ver_ = 1;
3290 int end_ver_ = MAX_CUSTOM_OP_END_VER;
3291};
3292
3293// Forward declaration to resolve circular dependency
3294// on ConstNode
3296
3297namespace detail {
3298template <typename T>
3300 using B = Base<T>;
3301 using B::B;
3302
3304 std::string GetName() const;
3310 std::vector<ValueInfoConsumerProducerInfo> GetConsumers() const;
3320 bool IsGraphOutput() const;
3324 bool IsFromOuterScope() const;
3325};
3326} // namespace detail
3327
3328// Const object holder that does not own the underlying object
3330
3335 ValueInfo() = default; // Same thing as with nullptr
3336 explicit ValueInfo(std::nullptr_t) {}
3338 explicit ValueInfo(OrtValueInfo* p) : ConstValueInfoImpl<OrtValueInfo>{p} {}
3339
3340#if !defined(ORT_MINIMAL_BUILD)
3341 // Create ValueInfo for a tensor
3342 explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info);
3343#endif
3344 ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; }
3345};
3346
3347// Forward declaration
3348struct AttrNameSubgraph;
3349
3350namespace detail {
3351// Forward decl
3352template <typename T>
3353struct ConstGraphImpl;
3354
// Read-only accessor layer for a node handle. Every member function is const; the note above
// each declaration names the C API function it wraps.
3355template <typename T>
3356struct ConstNodeImpl : Base<T> {
3357 using B = Base<T>;
3358 using B::B;
3359
3360 // Wraps OrtApi::Node_GetId
3361 size_t GetId() const;
3362 // Wraps OrtApi::Node_GetName
3363 std::string GetName() const;
3364 // Wraps OrtApi::Node_GetOperatorType
3365 std::string GetOperatorType() const;
3366 // Wraps OrtApi::Node_GetDomain
3367 std::string GetDomain() const;
3368 // Wraps OrtApi::Node_GetSinceVersion
3369 int GetSinceVersion() const;
3370
3371 // Wraps OrtApi::Node_Inputs
3372 std::vector<ConstValueInfo> GetInputs() const;
3373 // Wraps OrtApi::Node_Outputs
3374 std::vector<ConstValueInfo> GetOutputs() const;
3375 // Wraps OrtApi::Node_ImplicitInputs
3376 std::vector<ConstValueInfo> GetImplicitInputs() const;
3377 // Wraps OrtApi::Node_GetAttributes
3378 std::vector<ConstOpAttr> GetAttributes() const;
3379 // Wraps OrtApi::Node_GetAttributeByName
3380 // Please, read C API doc for details
// Unlike the other accessors this one returns a Status and hands the attribute back through `attr`.
3381 Status GetAttributeByName(const std::string& name, ConstOpAttr& attr) const;
3382 // Wraps OrtApi::Node_GetSubgraphs
3383 std::vector<AttrNameSubgraph> GetSubgraphs() const;
3384 // Wraps OrtApi::Node_GetGraph
3385 // ConstGraph is not available yet
3387 // Wraps OrtApi::Node_GetEpName
3388 std::string GetEpName() const;
3389};
3390} // namespace detail
3391
3393
// Wrapper around an OrtNode pointer that exposes the ConstNodeImpl accessors.
// NOTE(review): presumably this is the owning counterpart of ConstNode - confirm against detail::Base.
3397struct Node : detail::ConstNodeImpl<OrtNode> {
3398 Node() = default; // Same thing as with nullptr
3399 explicit Node(std::nullptr_t) {}
// Wraps an existing OrtNode pointer.
3400 explicit Node(OrtNode* p) : ConstNodeImpl<OrtNode>{p} {}
3401
3402#if !defined(ORT_MINIMAL_BUILD)
// Creates a node from operator name/domain, node name and input/output names (no attributes).
3403 Node(const std::string& operator_name, const std::string& operator_domain,
3404 const std::string& node_name,
3405 const std::vector<std::string>& input_names,
3406 const std::vector<std::string>& output_names);
3407
// Same as above, with attributes. `attributes` is taken by non-const reference.
// NOTE(review): presumably the OpAttr handles are handed over to the new node - confirm in the inline implementation.
3411 Node(const std::string& operator_name, const std::string& operator_domain,
3412 const std::string& node_name,
3413 const std::vector<std::string>& input_names,
3414 const std::vector<std::string>& output_names,
3415 std::vector<OpAttr>& attributes);
3416
3417 private:
// Shared construction helper for the constructors above; the created node is returned through `node`.
3418 static void Init(const std::string& operator_name, const std::string& operator_domain,
3419 const std::string& node_name,
3420 const std::vector<std::string>& input_names,
3421 const std::vector<std::string>& output_names,
3422 std::vector<OpAttr>& attributes,
3423 OrtNode*& node);
3424#endif // !defined(ORT_MINIMAL_BUILD)
3425};
3426
3427// Return struct for some of ValueInfo APIs.
3428// Must be declared after ConstNode is available.
3431 // index of the producer's output or of the consumer's input that refers to this value
3432 // a producer (output) index is never negative; a consumer (input) index can be -1
3433 int64_t index;
3434};
3435
3436// Represents a return value for Graph::GetOperatorSets()
3438 std::string domain;
3439 int64_t version;
3440};
3441
3442namespace detail {
3443template <typename T>
3445 using B = Base<T>;
3446 using B::B;
3447
3448 // <Wraps OrtApi::Graph_GetName
3449 std::string GetName() const;
3450 // <Wraps OrtApi::Graph_GetModelPath
3451 std::basic_string<ORTCHAR_T> GetModelPath() const;
3452 // <Wraps OrtApi::Graph_GetOnnxIRVersion
3453 int64_t GetOnnxIRVersion() const;
3454 // <Wraps OrtApi::Graph_GetOperatorSets
3455 std::vector<OperatorSet> GetOperatorSets() const;
3456 // <Wraps OrtApi::Graph_Inputs
3457 std::vector<ConstValueInfo> GetInputs() const;
3458 // <Wraps OrtApi::Graph_Outputs
3459 std::vector<ConstValueInfo> GetOutputs() const;
3460 // <Wraps OrtApi::Graph_Initializers
3461 std::vector<ConstValueInfo> GetInitializers() const;
3462 // <Wraps OrtApi::Graph_GetNodes
3463 std::vector<ConstNode> GetNodes() const;
3464 // <Wraps OrtApi::Graph_GetParentGraph
3466 // <Wraps OrtApi::Graph_GetGraphView
3467 Graph GetGraphView(const std::vector<ConstNode>& nodes) const;
3468 // <Wraps OrtApi::Graph_GetModelMetadata
3470};
3471
3472template <typename T>
3475 using B::B;
3476
3477#if !defined(ORT_MINIMAL_BUILD)
3478 // <Wraps GetModelEditorApi().SetGraphInputs()
3479 void SetInputs(std::vector<ValueInfo>& inputs);
3480 // <Wraps GetModelEditorApi().SetGraphOutputs()
3481 void SetOutputs(std::vector<ValueInfo>& outputs);
3482 // <Wraps GetModelEditorApi().AddInitializerToGraph()
3483 void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value
3484 // <Wraps GetModelEditorApi().AddNodeToGraph()
3485 void AddNode(Node& node); // Graph takes ownership of Node
3486#endif // !defined(ORT_MINIMAL_BUILD)
3487};
3488} // namespace detail
3489
3491
3492// Return value for Node API
3493// Must be declared after ConstGraph
3498
// Wrapper around an OrtGraph pointer that exposes the GraphImpl interface.
3502struct Graph : detail::GraphImpl<OrtGraph> {
// Creates an empty wrapper; it must be assigned a real instance before use.
3503 explicit Graph(std::nullptr_t) {}
// Wraps an existing OrtGraph pointer.
3504 explicit Graph(OrtGraph* p) : GraphImpl<OrtGraph>{p} {}
3505#if !defined(ORT_MINIMAL_BUILD)
3506 // Wraps GetModelEditorApi().CreateGraph()
3508#endif
3509};
3510
3511namespace detail {
3512template <typename T>
3515 using B::B;
3516
3517#if !defined(ORT_MINIMAL_BUILD)
3518 // <Wraps GetModelEditorApi().AddGraphToModel()
3519 void AddGraph(Graph& graph);
3520#endif
3521};
3522} // namespace detail
3523
3524// Const object holder that does not own the underlying object
3526
// Wrapper around an OrtModel pointer that exposes the ModelImpl interface.
3530struct Model : detail::ModelImpl<OrtModel> {
// (domain, opset version) pair used to describe the opsets of a new model.
3531 using DomainOpsetPair = std::pair<std::string, int>;
3532
// Creates an empty wrapper; it must be assigned a real instance before use.
3533 explicit Model(std::nullptr_t) {}
// Wraps an existing OrtModel pointer.
3534 explicit Model(OrtModel* p) : ModelImpl<OrtModel>{p} {}
3535
3536#if !defined(ORT_MINIMAL_BUILD)
3537 // Wraps GetModelEditorApi().CreateModel()
3538 explicit Model(const std::vector<DomainOpsetPair>& opsets);
3539#endif
3540};
3541
3542namespace detail {
3543template <typename T>
3545 using B = Base<T>;
3546 using B::B;
3547
3549 const char* GetOperatorType() const;
3550
3552 const char* GetDomain() const;
3553
3555 std::pair<int, int> GetSinceVersion() const;
3556
3558 const char* GetExecutionProvider() const;
3559
3561 OrtMemType GetInputMemType(size_t input_index) const;
3562
3564 OrtMemType GetOutputMemType(size_t output_index) const;
3565};
3566} // namespace detail
3567
3569
3572 using Base::Base;
3573
3574 explicit KernelDef(std::nullptr_t) {}
3575 explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl<OrtKernelDef>{p} {}
3576
3577 ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; }
3578};
3579
// Wrapper around an OrtKernelDefBuilder pointer.
// Every setter returns a reference to the builder itself, so calls can be chained.
3584struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
// Creates an empty wrapper; it must be assigned a real instance before use.
3586 explicit KernelDefBuilder(std::nullptr_t) {}
// Wraps an existing OrtKernelDefBuilder pointer.
3587 explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder);
3588
3589 KernelDefBuilder& SetOperatorType(const char* op_type);
3590 KernelDefBuilder& SetDomain(const char* domain);
// Takes both ends of the since-version range.
3591 KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end);
3593 KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type);
3594 KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type);
// Type constraints can be added for a single data type or for a list of data types.
3595 KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtDataType* data_type);
3596 KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector<const OrtDataType*>& data_types);
// Alias setters come in single-pair and vector forms.
// NOTE(review): the vector forms presumably pair input_indices[i] with output_indices[i] - confirm.
3597 KernelDefBuilder& AddInputOutputAlias(int input_index, int output_index);
3598 KernelDefBuilder& AddInputOutputAliases(const std::vector<int>& input_indices,
3599 const std::vector<int>& output_indices);
3600 KernelDefBuilder& AddInputOutputMutableAlias(int input_index, int output_index);
3601 KernelDefBuilder& AddInputOutputMutableAliases(const std::vector<int>& input_indices,
3602 const std::vector<int>& output_indices);
3603
3605};
3606
// Wrapper around an OrtKernelRegistry pointer.
3611struct KernelRegistry : detail::Base<OrtKernelRegistry> {
3614
// Creates an empty wrapper; it must be assigned a real instance before use.
3616 explicit KernelRegistry(std::nullptr_t) {}
3617
// Wraps an existing OrtKernelRegistry pointer.
3619 explicit KernelRegistry(OrtKernelRegistry* ort_kernel_registry);
3620
// Adds a kernel definition together with its creation function and an opaque state pointer for that function.
// Returns a Status instead of throwing.
3622 Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
3623 void* kernel_create_func_state);
3624};
3625
3626namespace detail {
3627template <typename T>
3629 using B = Base<T>;
3630 using B::B;
3631
3633 std::string GetTypeParamName() const;
3634
3636 std::vector<std::string> GetAllowedTypes() const;
3637
3639 std::vector<size_t> GetInputIndices() const;
3640
3642 std::vector<size_t> GetOutputIndices() const;
3643};
3644} // namespace detail
3645
3650
3651namespace detail {
// Accessor layer for an operator schema handle. All members shown here are const getters.
3652template <typename T>
3653struct OpSchemaImpl : Base<T> {
3654 using B = Base<T>;
3655 using B::B;
3656
3658 int GetSinceVersion() const;
3659
// Inputs are addressed by index; GetNumInputs() gives the count.
3661 size_t GetNumInputs() const;
3662
3664 std::string GetInputName(size_t index) const;
3665
3669
// Outputs are addressed by index; GetNumOutputs() gives the count.
3671 size_t GetNumOutputs() const;
3672
3674 std::string GetOutputName(size_t index) const;
3675
3679
3682
3685};
3686} // namespace detail
3687
3693
// Get an operator schema from the global schema registry, looked up by operator name, version and domain.
// NOTE(review): `max_inclusive_version` presumably is the highest opset version the lookup may match - confirm.
3699OpSchema GetOpSchema(const char* name, int max_inclusive_version, const char* domain);
3700
3701namespace detail {
3702template <typename T>
3705 using B::B;
3706
3707 //< Wraps SharedPrePackedWeightCache_StoreWeightData
3708 Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers);
3709};
3710} // namespace detail
3711
3729
3732
3733} // namespace Ort
3734#include "onnxruntime_cxx_inline.h"
struct OrtMemoryInfo OrtMemoryInfo
Definition onnxruntime_c_api.h:299
struct OrtKernelInfo OrtKernelInfo
Definition onnxruntime_c_api.h:475
struct OrtNode OrtNode
Definition onnxruntime_c_api.h:327
OrtLoggingLevel
Logging severity levels.
Definition onnxruntime_c_api.h:249
OrtMemoryInfoDeviceType
This mimics OrtDevice type constants so they can be returned in the API.
Definition onnxruntime_c_api.h:510
struct OrtShapeInferContext OrtShapeInferContext
Definition onnxruntime_c_api.h:324
void(* OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
Definition onnxruntime_c_api.h:439
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
Definition onnxruntime_c_api.h:976
OrtCustomOpInputOutputCharacteristic
Definition onnxruntime_c_api.h:7456
struct OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsV2
Definition onnxruntime_c_api.h:316
struct OrtThreadingOptions OrtThreadingOptions
Definition onnxruntime_c_api.h:313
struct OrtSequenceTypeInfo OrtSequenceTypeInfo
Definition onnxruntime_c_api.h:307
struct OrtValueInfo OrtValueInfo
Definition onnxruntime_c_api.h:326
struct OrtDnnlProviderOptions OrtDnnlProviderOptions
Definition onnxruntime_c_api.h:320
OrtSparseIndicesFormat
Definition onnxruntime_c_api.h:238
struct OrtPrepackedWeightsContainer OrtPrepackedWeightsContainer
Definition onnxruntime_c_api.h:315
struct OrtSession OrtSession
Definition onnxruntime_c_api.h:301
OrtCompiledModelCompatibility
Definition onnxruntime_c_api.h:1189
OrtStatus *(* EpSelectionDelegate)(const OrtEpDevice **ep_devices, size_t num_devices, const OrtKeyValuePairs *model_metadata, const OrtKeyValuePairs *runtime_metadata, const OrtEpDevice **selected, size_t max_selected, size_t *num_selected, void *state)
Delegate to allow providing custom OrtEpDevice selection logic.
Definition onnxruntime_c_api.h:564
struct OrtCustomOpDomain OrtCustomOpDomain
Definition onnxruntime_c_api.h:310
struct OrtIoBinding OrtIoBinding
Definition onnxruntime_c_api.h:300
struct OrtExternalInitializerInfo OrtExternalInitializerInfo
Definition onnxruntime_c_api.h:335
OrtAllocatorType
Definition onnxruntime_c_api.h:481
struct OrtOp OrtOp
Definition onnxruntime_c_api.h:321
struct OrtTypeInfo OrtTypeInfo
Definition onnxruntime_c_api.h:304
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:305
struct OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2
Definition onnxruntime_c_api.h:318
struct OrtProfilingEvent OrtProfilingEvent
Definition onnxruntime_ep_c_api.h:35
struct OrtKernelContext OrtKernelContext
Definition onnxruntime_c_api.h:477
struct OrtCANNProviderOptions OrtCANNProviderOptions
Definition onnxruntime_c_api.h:319
struct OrtEpDevice OrtEpDevice
Definition onnxruntime_c_api.h:332
void(* RunAsyncCallbackFn)(void *user_data, OrtValue **outputs, size_t num_outputs, OrtStatusPtr status)
Callback function for RunAsync.
Definition onnxruntime_c_api.h:1050
OrtHardwareDeviceType
Definition onnxruntime_c_api.h:517
struct OrtModel OrtModel
Definition onnxruntime_c_api.h:329
struct OrtGraph OrtGraph
Definition onnxruntime_c_api.h:328
struct OrtSyncStream OrtSyncStream
Definition onnxruntime_c_api.h:334
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:309
OrtDeviceMemoryType
This matches OrtDevice::MemoryType values.
Definition onnxruntime_c_api.h:503
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:302
OrtStatus *(* OrtWriteBufferFunc)(void *state, const void *buffer, size_t buffer_num_bytes)
Function called by ORT to write a buffer to a custom destination (e.g., file, stream,...
Definition onnxruntime_c_api.h:583
GraphOptimizationLevel
Graph optimization level.
Definition onnxruntime_c_api.h:448
struct OrtKeyValuePairs OrtKeyValuePairs
Definition onnxruntime_c_api.h:333
OrtStatus * OrtStatusPtr
Definition onnxruntime_c_api.h:346
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
Definition onnxruntime_c_api.h:491
OrtSparseFormat
Definition onnxruntime_c_api.h:230
ONNXType
Definition onnxruntime_c_api.h:218
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:297
OrtErrorCode
Definition onnxruntime_c_api.h:257
struct OrtStatus OrtStatus
Definition onnxruntime_c_api.h:298
OrtStatus *(* OrtGetInitializerLocationFunc)(void *state, const char *initializer_name, const OrtValue *initializer_value, const OrtExternalInitializerInfo *external_info, OrtExternalInitializerInfo **new_external_info)
Function called by ORT to allow user to specify how an initializer should be saved,...
Definition onnxruntime_c_api.h:617
#define ORT_API_VERSION
The API version defined in this header.
Definition onnxruntime_c_api.h:41
struct OrtLogger OrtLogger
Definition onnxruntime_c_api.h:323
struct OrtMapTypeInfo OrtMapTypeInfo
Definition onnxruntime_c_api.h:306
struct OrtArenaCfg OrtArenaCfg
Definition onnxruntime_c_api.h:314
ExecutionMode
Definition onnxruntime_c_api.h:456
OrtOpAttrType
Definition onnxruntime_c_api.h:275
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
Definition onnxruntime_c_api.h:969
ONNXTensorElementDataType
Definition onnxruntime_c_api.h:184
OrtExecutionProviderDevicePolicy
These are the default EP selection policies used by ORT when doing automatic EP selection.
Definition onnxruntime_c_api.h:525
const OrtApiBase * OrtGetApiBase(void)
The Onnxruntime library's entry point to access the C API.
@ ORT_LOGGING_LEVEL_WARNING
Warning messages.
Definition onnxruntime_c_api.h:252
@ OrtMemTypeDefault
The default allocator for execution provider.
Definition onnxruntime_c_api.h:499
@ ORT_FAIL
Definition onnxruntime_c_api.h:259
@ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
Definition onnxruntime_c_api.h:186
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
void OrtRelease(OrtAllocator *ptr)
Definition onnxruntime_cxx_api.h:629
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
All C++ Onnxruntime APIs are defined inside this namespace.
Definition onnxruntime_cxx_api.h:48
Ort::KeyValuePairs GetEnvConfigEntries()
const OrtModelEditorApi & GetModelEditorApi()
This returns a reference to the ORT C Model Editor API. Used if building or augmenting a model at run...
Definition onnxruntime_cxx_api.h:215
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Definition onnxruntime_cxx_api.h:807
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions > > ConstSessionOptions
Definition onnxruntime_cxx_api.h:1673
detail::KernelInfoImpl< detail::Unowned< const OrtKernelInfo > > ConstKernelInfo
Definition onnxruntime_cxx_api.h:3029
const OrtApi & GetApi() noexcept
This returns a reference to the ORT C API.
Definition onnxruntime_cxx_api.h:189
const OrtCompileApi & GetCompileApi()
This returns a reference to the ORT C Compile API. Used if compiling a model at runtime.
Definition onnxruntime_cxx_api.h:229
AllocatedStringPtr GetCompatibilityInfoFromModelAllocated(const char *model_path, const char *ep_type, OrtAllocator *allocator)
Extract EP compatibility info from a precompiled model file.
AllocatedStringPtr GetCompatibilityInfoFromModelBytesAllocated(const void *model_data, size_t model_data_length, const char *ep_type, OrtAllocator *allocator)
Extract EP compatibility info from precompiled model bytes in memory.
detail::AllocatorImpl< detail::Unowned< OrtAllocator > > UnownedAllocator
Definition onnxruntime_cxx_api.h:1090
OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(const std::vector< ConstEpDevice > &ep_devices, const char *compatibility_info)
Validate a compiled model's compatibility for one or more EP devices.
const OrtInteropApi & GetInteropApi()
This returns a reference to the ORT C Interop API. Used for external resource import with EPs.
Definition onnxruntime_cxx_api.h:243
OpSchema GetOpSchema(const char *name, int max_inclusive_version, const char *domain)
Get an operator schema from the global schema registry.
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions > > UnownedSessionOptions
Definition onnxruntime_cxx_api.h:1672
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id,...
const OrtEpApi & GetEpApi()
This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider.
Definition onnxruntime_cxx_api.h:257
std::string GetVersionString()
This function returns the onnxruntime version string.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Ort::Status(*)(Ort::ShapeInferContext &) ShapeInferFn
Definition onnxruntime_cxx_api.h:3148
Status CompileModel(const Env &env, const ModelCompilationOptions &model_compilation_options)
Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels....
Wrapper around OrtAllocator.
Definition onnxruntime_cxx_api.h:1082
Allocator(const Session &session, const OrtMemoryInfo *)
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition onnxruntime_cxx_api.h:1083
Allocator(OrtAllocator *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:1087
Wrapper around OrtAllocator default instance that is owned by Onnxruntime.
Definition onnxruntime_cxx_api.h:1073
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition onnxruntime_cxx_api.h:1074
it is a structure that represents the configuration of an arena based allocator
Definition onnxruntime_cxx_api.h:2726
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2727
ArenaCfg(const std::unordered_map< std::string, size_t > &arena_config)
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk)
Definition onnxruntime_cxx_api.h:3494
ConstGraph sub_graph
Definition onnxruntime_cxx_api.h:3496
std::string attr_name
Definition onnxruntime_cxx_api.h:3495
bfloat16 (Brain Floating Point) data type
Definition onnxruntime_cxx_api.h:427
bool operator==(const BFloat16_t &rhs) const noexcept
onnxruntime_float16::BFloat16Impl< BFloat16_t > Base
Definition onnxruntime_cxx_api.h:439
BFloat16_t()=default
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Explicit conversion to uint16_t representation of bfloat16.
Definition onnxruntime_cxx_api.h:448
bool operator!=(const BFloat16_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:546
BFloat16_t(float v) noexcept
__ctor from float. Float is converted into bfloat16 16-bit representation.
Definition onnxruntime_cxx_api.h:454
float ToFloat() const noexcept
Converts bfloat16 to float.
Definition onnxruntime_cxx_api.h:460
bool operator<(const BFloat16_t &rhs) const noexcept
The CUDAProviderOptions (V2)
Definition onnxruntime_cxx_api.h:878
CUDAProviderOptions()
Wraps OrtApi::CreateCUDAProviderOptions.
CUDAProviderOptions(std::nullptr_t)
Definition onnxruntime_cxx_api.h:879
void UpdateWithValue(const char *key, void *value)
Wrapper around OrtApi::UpdateCUDAProviderOptionsWithValue.
std::string GetCUDAProviderOptionsAsString() const
Wrapper around OrtApi::GetCUDAProviderOptionsAsString.
void Update(const std::unordered_map< std::string, std::string > &options)
Wrapper around OrtApi::UpdateCUDAProviderOptions.
void * GetOptionByName(const char *name) const
Wrapper around OrtApi::GetCUDAProviderOptionsByName.
Definition onnxruntime_cxx_api.h:3153
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
Definition onnxruntime_cxx_api.h:3228
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
Definition onnxruntime_cxx_api.h:3224
OrtMemType GetInputMemoryType(size_t) const
Definition onnxruntime_cxx_api.h:3233
std::vector< std::string > GetSessionConfigKeys() const
Definition onnxruntime_cxx_api.h:3264
bool GetVariadicInputHomogeneity() const
Definition onnxruntime_cxx_api.h:3245
int GetVariadicInputMinArity() const
Definition onnxruntime_cxx_api.h:3239
void SetShapeInferFn(...)
Definition onnxruntime_cxx_api.h:3281
CustomOpBase()
Definition onnxruntime_cxx_api.h:3154
bool GetVariadicOutputHomogeneity() const
Definition onnxruntime_cxx_api.h:3257
int GetVariadicOutputMinArity() const
Definition onnxruntime_cxx_api.h:3251
decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape))
Definition onnxruntime_cxx_api.h:3272
const char * GetExecutionProviderType() const
Definition onnxruntime_cxx_api.h:3220
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
Class that represents session configuration entries for one or more custom operators.
Definition onnxruntime_cxx_api.h:1523
~CustomOpConfigs()=default
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
CustomOpConfigs & operator=(CustomOpConfigs &&o)=default
CustomOpConfigs(CustomOpConfigs &&o)=default
CustomOpConfigs()=default
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
CustomOpConfigs(const CustomOpConfigs &)=default
CustomOpConfigs & operator=(const CustomOpConfigs &)=default
Custom Op Domain.
Definition onnxruntime_cxx_api.h:1406
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1410
CustomOpDomain(const char *domain)
Wraps OrtApi::CreateCustomOpDomain.
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
The Env (Environment)
Definition onnxruntime_cxx_api.h:1336
Env & EnableTelemetryEvents()
Wraps OrtApi::EnableTelemetryEvents.
Env(OrtEnv *p)
C Interop Helper.
Definition onnxruntime_cxx_api.h:1356
Env & CreateAndRegisterAllocatorV2(const std::string &provider_type, const OrtMemoryInfo *mem_info, const std::unordered_map< std::string, std::string > &options, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocatorV2.
Env & UnregisterExecutionProviderLibrary(const char *registration_name)
Wraps OrtApi::UnregisterExecutionProviderLibrary.
Env & SetPerSessionThreadPoolCallbacks(const OrtThreadPoolCallbacksConfig &config)
Wraps OrtApi::SetPerSessionThreadPoolCallbacks Stores work callbacks on the Env for per-session threa...
std::vector< ConstEpDevice > GetEpDevices() const
Env & UnregisterAllocator(const OrtMemoryInfo *mem_info)
Wraps OrtApi::UnregisterAllocator.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1337
Status CopyTensor(const OrtValue *src_tensor, OrtValue *dst_tensor, OrtSyncStream *stream) const
Env(OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnv.
Env(const OrtThreadingOptions *tp_options, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithGlobalThreadPools.
Env(const OrtThreadingOptions *tp_options, OrtLoggingFunction logging_function, void *logger_param, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools.
Env & RegisterAllocator(OrtAllocator *allocator)
Wraps OrtApi::RegisterAllocator.
UnownedAllocator CreateSharedAllocator(const OrtEpDevice *ep_device, OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, const OrtKeyValuePairs *allocator_options)
Wraps OrtApi::CreateSharedAllocator.
Env(OrtLoggingLevel logging_level, const char *logid, OrtLoggingFunction logging_function, void *logger_param)
Wraps OrtApi::CreateEnvWithCustomLogger.
Env(const OrtEnvCreationOptions *options)
Wraps OrtApi::CreateEnvWithOptions.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
UnownedAllocator GetSharedAllocator(const OrtMemoryInfo *mem_info)
Wraps OrtApi::GetSharedAllocator.
Env & RegisterExecutionProviderLibrary(const char *registration_name, const std::basic_string< char > &path)
Wraps OrtApi::RegisterExecutionProviderLibrary.
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
Status CopyTensors(const std::vector< Value > &src_tensors, const std::vector< Value > &dst_tensors, OrtSyncStream *stream) const
Wraps OrtApi::CopyTensors.
void ReleaseSharedAllocator(const OrtEpDevice *ep_device, OrtDeviceMemoryType mem_type)
Wraps OrtApi::ReleaseSharedAllocator.
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
Mutable EpDevice that is created by EpApi users.
Definition onnxruntime_cxx_api.h:1157
EpDevice(OrtEpDevice *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:1159
EpDevice(OrtEpFactory &ep_factory, ConstHardwareDevice &hardware_device, ConstKeyValuePairs ep_metadata={}, ConstKeyValuePairs ep_options={})
Wraps OrtEpApi::CreateEpDevice.
EpDevice(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:1158
All C++ methods that can fail will throw an exception of this type.
Definition onnxruntime_cxx_api.h:54
const char * what() const noexcept override
Definition onnxruntime_cxx_api.h:59
Exception(const std::string &string, OrtErrorCode code)
Definition onnxruntime_cxx_api.h:55
OrtErrorCode GetOrtErrorCode() const
Definition onnxruntime_cxx_api.h:58
Exception(std::string &&string, OrtErrorCode code)
Definition onnxruntime_cxx_api.h:56
Wrapper around OrtExternalInitializerInfo.
Definition onnxruntime_cxx_api.h:929
ConstExternalInitializerInfo GetConst() const
Definition onnxruntime_cxx_api.h:937
ExternalInitializerInfo(const char *filepath, int64_t file_offset, size_t byte_size)
Wraps OrtApi::CreateExternalInitializerInfo.
ExternalInitializerInfo(std::nullptr_t)
Definition onnxruntime_cxx_api.h:933
ExternalInitializerInfo(OrtExternalInitializerInfo *p)
Definition onnxruntime_cxx_api.h:934
static Status Create(const char *filepath, int64_t file_offset, size_t byte_size, ExternalInitializerInfo &out)
Wrapper around CreateExternalInitializerInfo that does not throw an exception.
IEEE 754 half-precision floating point data type.
Definition onnxruntime_cxx_api.h:285
Float16_t()=default
Default constructor.
Float16_t(float v) noexcept
__ctor from float. Float is converted into float16 16-bit representation.
Definition onnxruntime_cxx_api.h:313
onnxruntime_float16::Float16Impl< Float16_t > Base
Definition onnxruntime_cxx_api.h:295
float ToFloat() const noexcept
Converts float16 to float.
Definition onnxruntime_cxx_api.h:319
static constexpr Float16_t FromBits(uint16_t v) noexcept
Explicit conversion from the uint16_t bit representation of float16.
Definition onnxruntime_cxx_api.h:307
float8e4m3fn (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:557
uint8_t value
Definition onnxruntime_cxx_api.h:558
constexpr Float8E4M3FN_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:560
constexpr bool operator==(const Float8E4M3FN_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:563
constexpr Float8E4M3FN_t() noexcept
Definition onnxruntime_cxx_api.h:559
constexpr bool operator!=(const Float8E4M3FN_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:564
float8e4m3fnuz (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:574
constexpr bool operator==(const Float8E4M3FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:580
uint8_t value
Definition onnxruntime_cxx_api.h:575
constexpr Float8E4M3FNUZ_t() noexcept
Definition onnxruntime_cxx_api.h:576
constexpr bool operator!=(const Float8E4M3FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:581
constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:577
float8e5m2 (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:591
constexpr Float8E5M2_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:594
uint8_t value
Definition onnxruntime_cxx_api.h:592
constexpr bool operator!=(const Float8E5M2_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:598
constexpr Float8E5M2_t() noexcept
Definition onnxruntime_cxx_api.h:593
constexpr bool operator==(const Float8E5M2_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:597
float8e5m2fnuz (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:608
constexpr Float8E5M2FNUZ_t() noexcept
Definition onnxruntime_cxx_api.h:610
constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:611
constexpr bool operator!=(const Float8E5M2FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:615
constexpr bool operator==(const Float8E5M2FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:614
uint8_t value
Definition onnxruntime_cxx_api.h:609
Wrapper around OrtGraph.
Definition onnxruntime_cxx_api.h:3502
Graph(OrtGraph *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3504
Graph(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:3503
Wrapper around OrtIoBinding.
Definition onnxruntime_cxx_api.h:2715
UnownedIoBinding GetUnowned() const
Definition onnxruntime_cxx_api.h:2719
ConstIoBinding GetConst() const
Definition onnxruntime_cxx_api.h:2718
IoBinding(Session &session)
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
Definition onnxruntime_cxx_api.h:2716
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
Definition onnxruntime_cxx_api.h:2949
KernelContext(OrtKernelContext *context)
Logger GetLogger() const
ConstValue GetInput(size_t index) const
OrtKernelContext * GetOrtKernelContext() const
Definition onnxruntime_cxx_api.h:2963
void ParallelFor(void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
void * GetGPUComputeStream() const
size_t GetInputCount() const
Ort::Allocator GetAllocator(const OrtMemoryInfo &memory_info) const
size_t GetOutputCount() const
UnownedValue GetOutput(size_t index, const std::vector< int64_t > &dims) const
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
Builder for OrtKernelDef.
Definition onnxruntime_cxx_api.h:3584
KernelDefBuilder & AddTypeConstraint(const char *arg_name, const OrtDataType *data_type)
KernelDefBuilder & SetOutputMemType(size_t output_index, OrtMemType mem_type)
KernelDefBuilder & AddInputOutputMutableAliases(const std::vector< int > &input_indices, const std::vector< int > &output_indices)
KernelDefBuilder & SetInputMemType(size_t input_index, OrtMemType mem_type)
KernelDefBuilder & SetDomain(const char *domain)
KernelDefBuilder & AddInputOutputAliases(const std::vector< int > &input_indices, const std::vector< int > &output_indices)
KernelDefBuilder & AddInputOutputAlias(int input_index, int output_index)
KernelDefBuilder & SetExecutionProvider(const char *ep_name)
KernelDefBuilder & SetOperatorType(const char *op_type)
KernelDefBuilder & AddInputOutputMutableAlias(int input_index, int output_index)
KernelDefBuilder()
Wraps OrtEpApi::CreateKernelDefBuilder.
KernelDefBuilder & AddTypeConstraint(const char *arg_name, const std::vector< const OrtDataType * > &data_types)
KernelDefBuilder(std::nullptr_t)
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:3586
KernelDefBuilder(OrtKernelDefBuilder *ort_kernel_def_builder)
KernelDefBuilder & SetSinceVersion(int since_version_start, int since_version_end)
Definition onnxruntime_cxx_api.h:3570
KernelDef(OrtKernelDef *p)
Definition onnxruntime_cxx_api.h:3575
KernelDef(std::nullptr_t)
Definition onnxruntime_cxx_api.h:3574
ConstKernelDef GetConst() const
Definition onnxruntime_cxx_api.h:3577
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
Definition onnxruntime_cxx_api.h:3037
KernelInfo(OrtKernelInfo *info)
Take ownership of the instance.
ConstKernelInfo GetConst() const
Definition onnxruntime_cxx_api.h:3042
detail::KernelInfoImpl< OrtKernelInfo > Base
Definition onnxruntime_cxx_api.h:3038
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
Definition onnxruntime_cxx_api.h:3040
Registry for kernels supported by an EP.
Definition onnxruntime_cxx_api.h:3611
KernelRegistry()
Wrapper around OrtEpApi::CreateKernelRegistry.
KernelRegistry(std::nullptr_t)
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:3616
Status AddKernel(const OrtKernelDef *kernel_def, OrtKernelCreateFunc kernel_create_func, void *kernel_create_func_state)
Wraps KernelRegistry_AddKernel.
KernelRegistry(OrtKernelRegistry *ort_kernel_registry)
Take ownership of a pointer created with the C API.
Wrapper around OrtKeyValuePairs.
Definition onnxruntime_cxx_api.h:966
KeyValuePairs()
Wraps OrtApi::CreateKeyValuePairs.
void Add(const char *key, const char *value)
Wraps OrtApi::AddKeyValuePair.
KeyValuePairs(const std::unordered_map< std::string, std::string > &kv_pairs)
Wraps OrtApi::CreateKeyValuePairs and OrtApi::AddKeyValuePair.
void Remove(const char *key)
Wraps OrtApi::RemoveKeyValuePair.
KeyValuePairs(std::nullptr_t)
Definition onnxruntime_cxx_api.h:967
ConstKeyValuePairs GetConst() const
Definition onnxruntime_cxx_api.h:983
KeyValuePairs(OrtKeyValuePairs *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:969
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Definition onnxruntime_cxx_api.h:2871
Logger(Logger &&v) noexcept=default
Logger & operator=(Logger &&v) noexcept=default
Logger & operator=(const Logger &)=default
~Logger()=default
Logger(const Logger &)=default
Logger()=default
Logger(std::nullptr_t)
Definition onnxruntime_cxx_api.h:2880
Logger(const OrtLogger *logger)
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
LoraAdapter holds a set of Lora Parameters loaded from a single file.
Definition onnxruntime_cxx_api.h:1420
static LoraAdapter CreateLoraAdapter(const std::basic_string< char > &adapter_path, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapter.
LoraAdapter(std::nullptr_t)
Definition onnxruntime_cxx_api.h:1424
static LoraAdapter CreateLoraAdapterFromArray(const void *bytes, size_t num_bytes, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapterFromArray.
Wrapper around OrtMapTypeInfo.
Definition onnxruntime_cxx_api.h:2101
ConstMapTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:2107
MapTypeInfo(OrtMapTypeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:2106
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2105
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime....
Definition onnxruntime_cxx_api.h:1027
MemoryAllocation(MemoryAllocation &&) noexcept
MemoryAllocation & operator=(const MemoryAllocation &)=delete
MemoryAllocation(const MemoryAllocation &)=delete
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
size_t size() const
Definition onnxruntime_cxx_api.h:1036
Wrapper around OrtMemoryInfo.
Definition onnxruntime_cxx_api.h:1011
MemoryInfo(const char *name, OrtAllocatorType type, int id, OrtMemType mem_type)
MemoryInfo(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:1013
MemoryInfo(OrtMemoryInfo *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:1014
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
ConstMemoryInfo GetConst() const
Definition onnxruntime_cxx_api.h:1018
MemoryInfo(const char *name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type)
Wrapper around CreateMemoryInfo_V2.
Options object used when compiling a model.
Definition onnxruntime_cxx_api.h:1690
ModelCompilationOptions & SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void *state)
ModelCompilationOptions & SetEpContextEmbedMode(bool embed_ep_context_in_model)
Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode.
ModelCompilationOptions & SetInputModelFromBuffer(const void *input_model_data, size_t input_model_data_size)
Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer.
ModelCompilationOptions & SetOutputModelBuffer(OrtAllocator *allocator, void **output_model_buffer_ptr, size_t *output_model_buffer_size_ptr)
Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer.
ModelCompilationOptions & SetFlags(uint32_t flags)
Wraps OrtApi::ModelCompilationOptions_SetFlags.
ModelCompilationOptions & SetOutputModelExternalInitializersFile(const char *file_path, size_t initializer_size_threshold)
Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile.
ModelCompilationOptions(std::nullptr_t)
Create an empty ModelCompilationOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1694
ModelCompilationOptions(const Env &env, ConstSessionOptions session_options)
Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions.
ModelCompilationOptions & SetOutputModelPath(const char *output_model_path)
Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath.
ModelCompilationOptions & SetInputModelPath(const char *input_model_path)
Wraps OrtApi::ModelCompilationOptions_SetInputModelPath.
ModelCompilationOptions & SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, void *state)
ModelCompilationOptions & SetEpContextBinaryInformation(const char *output_directory, const char *model_name)
Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation.
ModelCompilationOptions & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel.
ModelCompilationOptions(const Env &env, const SessionOptions &session_options)
Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions.
ModelCompilationOptions & SetInputModel(const OrtModel *model)
Wraps OrtCompileApi::ModelCompilationOptions_SetInputModel.
Wrapper around OrtModel.
Definition onnxruntime_cxx_api.h:3530
Model(const std::vector< DomainOpsetPair > &opsets)
Model(OrtModel *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3534
std::pair< std::string, int > DomainOpsetPair
Definition onnxruntime_cxx_api.h:3531
Model(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:3533
Wrapper around OrtModelMetadata.
Definition onnxruntime_cxx_api.h:1738
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1742
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator *allocator) const
Returns a copy of the producer name.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Returns a copy of the graph name.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
Wrapper around OrtNode.
Definition onnxruntime_cxx_api.h:3397
Node(const std::string &operator_name, const std::string &operator_domain, const std::string &node_name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names)
Node()=default
Node(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:3399
Node(const std::string &operator_name, const std::string &operator_domain, const std::string &node_name, const std::vector< std::string > &input_names, const std::vector< std::string > &output_names, std::vector< OpAttr > &attributes)
Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in attributes ...
Node(OrtNode *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3400
This struct provides life time management for custom op attribute.
Definition onnxruntime_cxx_api.h:2780
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
OpAttr()=default
OpAttr(std::nullptr_t)
Definition onnxruntime_cxx_api.h:2785
ConstOpAttr GetConst() const
Definition onnxruntime_cxx_api.h:2788
Create and own custom defined operation.
Definition onnxruntime_cxx_api.h:3048
Op(OrtOp *)
Take ownership of the OrtOp.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:3052
void Invoke(const OrtKernelContext *context, const OrtValue *const *input_values, size_t input_count, OrtValue *const *output_values, size_t output_count)
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
Definition onnxruntime_cxx_api.h:3437
std::string domain
Definition onnxruntime_cxx_api.h:3438
int64_t version
Definition onnxruntime_cxx_api.h:3439
The PrepackedWeightsContainer.
Definition onnxruntime_cxx_api.h:897
PrepackedWeightsContainer()
Wraps OrtApi::CreatePrepackedWeightsContainer.
PrepackedWeightsContainer(OrtPrepackedWeightsContainer *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:902
PrepackedWeightsContainer(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:900
Owning wrapper around OrtProfilingEvent.
Definition onnxruntime_cxx_api.h:1283
ProfilingEvent(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:1284
ConstProfilingEvent GetConst() const
Definition onnxruntime_cxx_api.h:1308
ProfilingEvent(OrtProfilingEventCategory category, int32_t process_id, int32_t thread_id, const char *event_name, int64_t timestamp_us, int64_t duration_us, const std::unordered_map< std::string, std::string > &args={})
Wraps OrtEpApi::CreateProfilingEvent.
ProfilingEvent(OrtProfilingEvent *p)
Take ownership.
Definition onnxruntime_cxx_api.h:1285
ProfilingEvent(OrtProfilingEventCategory category, int32_t process_id, int32_t thread_id, const char *event_name, int64_t timestamp_us, int64_t duration_us, const char *const *arg_keys, const char *const *arg_values, size_t num_args)
Wraps OrtEpApi::CreateProfilingEvent.
RunOptions.
Definition onnxruntime_cxx_api.h:1448
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance.
RunOptions & DisableProfiling()
Disable profiling for this run.
RunOptions & SetSyncStream(OrtSyncStream *stream)
Associate a sync stream with the run options.
RunOptions & SetRunTag(const char *run_tag)
Wraps OrtApi::RunOptionsSetRunTag.
RunOptions & AddActiveLoraAdapter(const LoraAdapter &adapter)
Add the LoraAdapter to the list of active adapters. The setting does not affect RunWithBinding() call...
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1449
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
RunOptions & EnableProfiling(const char *profile_file_prefix)
Enable profiling for this run.
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
RunOptions()
Wraps OrtApi::CreateRunOptions.
const char * GetConfigEntry(const char *config_key)
Wraps OrtApi::GetRunConfigEntry.
Wrapper around OrtSequenceTypeInfo.
Definition onnxruntime_cxx_api.h:2063
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2067
ConstSequenceTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:2069
SequenceTypeInfo(OrtSequenceTypeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:2068
Wrapper around OrtSession.
Definition onnxruntime_cxx_api.h:1964
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1966
static Session CreateModelEditorSession(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtModelEditorApi::CreateModelEditorSession.
UnownedSession GetUnowned() const
Definition onnxruntime_cxx_api.h:1995
Session(const Env &env, const char *model_path, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer.
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer.
Session(const Env &env, const Model &model, const SessionOptions &options)
Wraps OrtModelEditorApi::CreateSessionFromModel.
Session(OrtSession *p)
C API Interop.
Definition onnxruntime_cxx_api.h:1967
static Session CreateModelEditorSession(const Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtModelEditorApi::CreateModelEditorSession.
Session(const Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtApi::CreateSession.
ConstSession GetConst() const
Definition onnxruntime_cxx_api.h:1994
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtApi::CreateSessionFromArray.
Wrapper around OrtSessionOptions.
Definition onnxruntime_cxx_api.h:1678
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1679
UnownedSessionOptions GetUnowned() const
Definition onnxruntime_cxx_api.h:1682
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
ConstSessionOptions GetConst() const
Definition onnxruntime_cxx_api.h:1683
SessionOptions(OrtSessionOptions *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1681
Definition onnxruntime_cxx_api.h:3082
SymbolicInteger & operator=(const SymbolicInteger &)=default
SymbolicInteger(const SymbolicInteger &)=default
int64_t AsInt() const
Definition onnxruntime_cxx_api.h:3103
int64_t i_
Definition onnxruntime_cxx_api.h:3110
const char * s_
Definition onnxruntime_cxx_api.h:3111
bool operator==(const SymbolicInteger &dim) const
Definition onnxruntime_cxx_api.h:3091
SymbolicInteger & operator=(SymbolicInteger &&)=default
SymbolicInteger(SymbolicInteger &&)=default
const char * AsSym() const
Definition onnxruntime_cxx_api.h:3104
SymbolicInteger(int64_t i)
Definition onnxruntime_cxx_api.h:3083
SymbolicInteger(const char *s)
Definition onnxruntime_cxx_api.h:3084
bool IsInt() const
Definition onnxruntime_cxx_api.h:3102
Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
Definition onnxruntime_cxx_api.h:3081
Ints GetAttrInts(const char *attr_name)
Strings GetAttrStrings(const char *attr_name)
Status SetOutputShape(size_t indice, const Shape &shape, ONNXTensorElementDataType type=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
std::vector< SymbolicInteger > Shape
Definition onnxruntime_cxx_api.h:3116
std::vector< float > Floats
Definition onnxruntime_cxx_api.h:3133
std::string GetAttrString(const char *attr_name)
std::vector< int64_t > Ints
Definition onnxruntime_cxx_api.h:3128
ShapeInferContext(const OrtApi *ort_api, OrtShapeInferContext *ctx)
int64_t GetAttrInt(const char *attr_name)
size_t GetInputCount() const
Definition onnxruntime_cxx_api.h:3122
std::vector< std::string > Strings
Definition onnxruntime_cxx_api.h:3138
Floats GetAttrFloats(const char *attr_name)
const Shape & GetInputShape(size_t indice) const
Definition onnxruntime_cxx_api.h:3120
float GetAttrFloat(const char *attr_name)
The Status that holds ownership of OrtStatus received from C API. Use it to safely destroy OrtStatus* ...
Definition onnxruntime_cxx_api.h:813
OrtErrorCode GetErrorCode() const
Status(const Exception &)
Creates status instance out of exception.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
Status(OrtStatus *status) noexcept
Takes ownership of OrtStatus instance returned from the C API.
std::string GetErrorMessage() const
Status()=default
Status(const std::exception &)
Creates status instance out of exception.
Status(const char *message, OrtErrorCode code)
Creates status instance out of null-terminated string message.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:815
Definition onnxruntime_cxx_api.h:1106
SyncStream(OrtSyncStream *p)
Definition onnxruntime_cxx_api.h:1110
SyncStream(std::nullptr_t)
Create an empty SyncStream object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1108
The TensorRTOptions (V2)
Definition onnxruntime_cxx_api.h:859
void Update(const std::unordered_map< std::string, std::string > &options)
Wrapper around OrtApi::UpdateTensorRTProviderOptions.
void UpdateWithValue(const char *key, void *value)
Wrapper around OrtApi::UpdateTensorRTProviderOptionsWithValue.
std::string GetTensorRTProviderOptionsAsString() const
Wrapper around OrtApi::GetTensorRTProviderOptionsAsString.
void * GetOptionByName(const char *name) const
Wrapper around OrtApi::GetTensorRTProviderOptionsByName.
TensorRTProviderOptions(std::nullptr_t)
Definition onnxruntime_cxx_api.h:860
TensorRTProviderOptions()
Wraps OrtApi::CreateTensorRTProviderOptionsV2.
Wrapper around OrtTensorTypeAndShapeInfo.
Definition onnxruntime_cxx_api.h:2029
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2034
ConstTensorTypeAndShapeInfo GetConst() const
Definition onnxruntime_cxx_api.h:2045
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:2036
TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, const std::vector< int64_t > &dims, const std::vector< std::string > *symbolic_dims=nullptr)
The ThreadingOptions.
Definition onnxruntime_cxx_api.h:829
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
Definition onnxruntime_cxx_api.h:2135
static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type)
static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type)
static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info)
static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info)
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2140
static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type)
ConstTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:2151
TypeInfo(OrtTypeInfo *p)
C API Interop.
Definition onnxruntime_cxx_api.h:2141
Wrapper around OrtValue.
Definition onnxruntime_cxx_api.h:2504
static Value CreateSparseTensor(const OrtMemoryInfo *info, void *p_data, const Shape &dense_shape, const Shape &values_shape, ONNXTensorElementDataType type)
Creates an OrtValue instance containing SparseTensor. This constructs a sparse tensor that makes use ...
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
Value & operator=(Value &&)=default
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape, ONNXTensorElementDataType type)
Creates an instance of OrtValue containing sparse tensor. The created instance has no data....
Value(Value &&)=default
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2510
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape)
This is a simple forwarding method to the below CreateSparseTensor. This helps to specify data type e...
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates an OrtValue with a tensor using the supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtVal...
UnownedValue GetUnowned() const
Definition onnxruntime_cxx_api.h:2515
static Value CreateSequence(const std::vector< Value > &values)
Creates an OrtValue with a Sequence Onnx type representation. The API would ref-count the supplied Or...
static Value CreateMap(const Value &keys, const Value &values)
Creates an OrtValue with a Map Onnx type representation. The API would ref-count the supplied OrtValu...
static Value CreateTensor(const OrtMemoryInfo *info, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len)
Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue...
static Value CreateOpaque(const char *domain, const char *type_name, const T &value)
Creates an OrtValue wrapping an Opaque type. This is used for experimental support of non-tensor type...
static Value CreateTensor(OrtAllocator *deleter, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue.
ConstValue GetConst() const
Definition onnxruntime_cxx_api.h:2514
Definition onnxruntime_cxx_api.h:3429
int64_t index
Definition onnxruntime_cxx_api.h:3433
ConstNode node
Definition onnxruntime_cxx_api.h:3430
Wrapper around OrtValueInfo.
Definition onnxruntime_cxx_api.h:3334
ConstValueInfo GetConst() const
Definition onnxruntime_cxx_api.h:3344
ValueInfo(std::nullptr_t)
Definition onnxruntime_cxx_api.h:3336
ValueInfo(const std::string &name, const ConstTypeInfo &type_info)
ValueInfo(OrtValueInfo *p)
Take ownership of a pointer created by C API.
Definition onnxruntime_cxx_api.h:3338
ValueInfo()=default
Definition onnxruntime_cxx_api.h:775
AllocatedFree(OrtAllocator *allocator)
Definition onnxruntime_cxx_api.h:777
OrtAllocator * allocator_
Definition onnxruntime_cxx_api.h:776
void operator()(void *ptr) const
Definition onnxruntime_cxx_api.h:779
Base & operator=(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:761
constexpr contained_type & operator*() const noexcept
Definition onnxruntime_cxx_api.h:768
typename Unowned< T >::Type contained_type
Definition onnxruntime_cxx_api.h:750
Base(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:760
Base(const Base &)=default
constexpr Base(contained_type *p) noexcept
Definition onnxruntime_cxx_api.h:753
Base & operator=(const Base &)=default
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition onnxruntime_cxx_api.h:703
Base(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:715
constexpr Base()=default
constexpr contained_type & operator*() const noexcept
Definition onnxruntime_cxx_api.h:723
contained_type * release()
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed.
Definition onnxruntime_cxx_api.h:727
Base(const Base &)=delete
constexpr Base(contained_type *p) noexcept
Definition onnxruntime_cxx_api.h:707
Base & operator=(const Base &)=delete
Base & operator=(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:716
contained_type * p_
Definition onnxruntime_cxx_api.h:734
~Base()
Definition onnxruntime_cxx_api.h:708
T contained_type
Definition onnxruntime_cxx_api.h:704
Definition onnxruntime_cxx_api.h:909
const std::basic_string< char > GetFilePath() const
Definition onnxruntime_cxx_api.h:3444
std::vector< ConstNode > GetNodes() const
std::vector< ConstValueInfo > GetInputs() const
ConstNode GetParentNode() const
int64_t GetOnnxIRVersion() const
std::basic_string< char > GetModelPath() const
Graph GetGraphView(const std::vector< ConstNode > &nodes) const
ModelMetadata GetModelMetadata() const
Wraps OrtApi::Graph_GetModelMetadata.
std::vector< ConstValueInfo > GetInitializers() const
std::string GetName() const
std::vector< ConstValueInfo > GetOutputs() const
std::vector< OperatorSet > GetOperatorSets() const
Definition onnxruntime_cxx_api.h:2683
std::vector< Value > GetOutputValues(OrtAllocator *) const
std::vector< std::string > GetOutputNames(OrtAllocator *) const
std::vector< Value > GetOutputValues() const
std::vector< std::string > GetOutputNames() const
Definition onnxruntime_cxx_api.h:3544
std::pair< int, int > GetSinceVersion() const
Wraps OrtEpApi::KernelDef_GetExecutionProvider.
const char * GetDomain() const
Wraps OrtEpApi::KernelDef_GetSinceVersion.
OrtMemType GetOutputMemType(size_t output_index) const
const char * GetExecutionProvider() const
Wraps OrtEpApi::KernelDef_GetInputMemType.
OrtMemType GetInputMemType(size_t input_index) const
Wraps OrtEpApi::KernelDef_GetOutputMemType.
const char * GetOperatorType() const
Wraps OrtEpApi::KernelDef_GetOperatorType.
Definition onnxruntime_cxx_api.h:3356
std::vector< ConstValueInfo > GetOutputs() const
std::vector< ConstValueInfo > GetImplicitInputs() const
std::string GetName() const
std::string GetDomain() const
std::vector< AttrNameSubgraph > GetSubgraphs() const
ConstGraphImpl< detail::Unowned< const OrtGraph > > GetGraph() const
std::string GetOperatorType() const
std::vector< ConstOpAttr > GetAttributes() const
std::vector< ConstValueInfo > GetInputs() const
Status GetAttributeByName(const std::string &name, ConstOpAttr &attr) const
std::string GetEpName() const
Definition onnxruntime_cxx_api.h:2752
std::string GetName() const
Status GetValue(R &out) const
Status GetTensorAttributeAsOrtValue(Value &) const
Status GetValueArray(std::vector< R > &out) const
OrtOpAttrType GetType() const
Definition onnxruntime_cxx_api.h:1244
int64_t GetTimestampUs() const
Get the start timestamp in microseconds. Wraps OrtEpApi::ProfilingEvent_GetTimestampUs.
const char * GetName() const
Get the event name. Wraps OrtEpApi::ProfilingEvent_GetName.
const char * GetArgValue(const char *key) const
Get the value of an event argument by key. Wraps OrtEpApi::ProfilingEvent_GetArgValue.
int64_t GetDurationUs() const
Get the duration in microseconds. Wraps OrtEpApi::ProfilingEvent_GetDurationUs.
OrtProfilingEventCategory GetCategory() const
Get the event category. Wraps OrtEpApi::ProfilingEvent_GetCategory.
Definition onnxruntime_cxx_api.h:1813
std::vector< std::string > GetOutputNames() const
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
std::vector< ConstEpAssignedSubgraph > GetEpGraphAssignmentInfo() const
Returns information on the subgraph/nodes assigned to execution providers in the session.
size_t GetOutputCount() const
Returns the number of model outputs.
std::vector< ValueInfo > GetOutputs() const
int GetOpset(const std::string &domain) const
Wraps OrtApi::SessionGetOpsetForDomain.
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
std::vector< ConstEpDevice > GetEpDeviceForOutputs() const
Wrapper for OrtApi::SessionGetEpDeviceForOutputs.
std::vector< std::string > GetOverridableInitializerNames() const
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
size_t GetInputCount() const
Returns the number of model inputs.
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
std::vector< std::string > GetInputNames() const
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
std::vector< ConstEpDevice > GetEpDeviceForInputs() const
Wrapper for OrtApi::SessionGetEpDeviceForInputs.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of output name at the specified index.
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
std::vector< ConstMemoryInfo > GetMemoryInfoForOutputs() const
Wrapper for OrtApi::SessionGetMemoryInfoForOutputs.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
std::vector< ConstMemoryInfo > GetMemoryInfoForInputs() const
Wrapper for OrtApi::SessionGetMemoryInfoForInputs.
std::vector< ValueInfo > GetInputs() const
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
Definition onnxruntime_cxx_api.h:2180
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
std::string GetStringTensorElement(size_t element_index) const
Returns string tensor UTF-8 encoded string element. Use of this API is recommended over GetStringTens...
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a sparse...
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor....
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor.
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
size_t GetTensorSizeInBytes() const
Returns the total size of the tensor data in bytes. Throws an exception if the OrtValue does not cont...
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values....
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
Value GetValue(int index, OrtAllocator *allocator) const
size_t GetCount() const
< Return true if OrtValue contains data and returns false if the OrtValue is a None
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor....
void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType &elem_type, Shape &shape) const
Returns the tensor's element type and a reference to the tensor's internal shape data....
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed,...
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
Definition onnxruntime_cxx_api.h:3299
Status GetInitializer(ConstValue &value) const
A wrapper around OrtApi::ValueInfo_GetInitializerValue.
std::string GetName() const
A wrapper around OrtApi::GetValueInfoName.
bool IsFromOuterScope() const
A wrapper around OrtApi::ValueInfo_IsFromOuterScope.
Status GetExternalInitializerInfo(ExternalInitializerInfo &info) const
A wrapper around OrtApi::ValueInfo_GetExternalInitializerInfo.
bool IsConstantInitializer() const
A wrapper around OrtApi::ValueInfo_IsConstantInitializer.
std::vector< ValueInfoConsumerProducerInfo > GetConsumers() const
A wrapper around OrtApi::ValueInfo_GetValueConsumers.
bool IsGraphOutput() const
A wrapper around OrtApi::ValueInfo_IsGraphOutput.
bool IsRequiredGraphInput() const
A wrapper around OrtApi::ValueInfo_IsRequiredGraphInput.
ConstTypeInfo TypeInfo() const
A wrapper around OrtApi::GetValueInfoTypeInfo.
ValueInfoConsumerProducerInfo GetProducerNode() const
bool IsOptionalGraphInput() const
A wrapper around OrtApi::ValueInfo_IsOptionalGraphInput.
Definition onnxruntime_cxx_api.h:1211
std::string GetDomain() const
std::string GetOperatorType() const
std::string GetName() const
Definition onnxruntime_cxx_api.h:1228
std::vector< ConstEpAssignedNode > GetNodes() const
Definition onnxruntime_cxx_api.h:1136
const char * EpName() const
const char * EpVendor() const
ConstKeyValuePairs EpOptions() const
ConstHardwareDevice Device() const
ConstMemoryInfo GetMemoryInfo(OrtDeviceMemoryType memory_type) const
Wraps EpDevice_MemoryInfo.
SyncStream CreateSyncStream(ConstKeyValuePairs stream_options={}) const
ConstKeyValuePairs EpMetadata() const
Definition onnxruntime_cxx_api.h:3473
void SetInputs(std::vector< ValueInfo > &inputs)
void SetOutputs(std::vector< ValueInfo > &outputs)
void AddNode(Node &node)
void AddInitializer(const std::string &name, Value &initializer, bool data_is_external)
Definition onnxruntime_cxx_api.h:1117
OrtHardwareDeviceType Type() const
const char * Vendor() const
ConstKeyValuePairs Metadata() const
Definition onnxruntime_cxx_api.h:2694
void BindOutput(const char *name, const Value &)
void BindInput(const char *name, const Value &)
void BindOutput(const char *name, const OrtMemoryInfo *)
Definition onnxruntime_cxx_api.h:949
void GetKeyValuePairs(std::vector< const char * > &keys, std::vector< const char * > &values) const
std::unordered_map< std::string, std::string > GetKeyValuePairs() const
const char * GetValue(const char *key) const
Definition onnxruntime_cxx_api.h:2087
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
Definition onnxruntime_cxx_api.h:988
std::string GetAllocatorName() const
Wrapper MemoryInfoGetName.
int GetDeviceId() const
Wrapper MemoryInfoGetId.
OrtMemType GetMemoryType() const
Wrapper MemoryInfoGetMemType.
OrtDeviceMemoryType GetDeviceMemoryType() const
Wrapper MemoryInfoGetDeviceMemType.
OrtMemoryInfoDeviceType GetDeviceType() const
Wrapper MemoryInfoGetDeviceType.
OrtAllocatorType GetAllocatorType() const
Wrapper MemoryInfoGetType.
uint32_t GetVendorId() const
Wrapper MemoryInfoGetVendorId.
bool operator==(const MemoryInfoImpl< U > &o) const
Definition onnxruntime_cxx_api.h:3513
void AddGraph(Graph &graph)
Definition onnxruntime_cxx_api.h:3653
std::string GetInputName(size_t index) const
int GetSinceVersion() const
Wraps OrtEpApi::OpSchema_GetSinceVersion.
size_t GetNumOutputs() const
Wraps OrtEpApi::OpSchema_GetOutputName.
ConstOpSchemaTypeConstraint GetOutputTypeConstraint(size_t index) const
Wraps OrtEpApi::OpSchema_GetTypeConstraintCount.
size_t GetTypeConstraintCount() const
Wraps OrtEpApi::OpSchema_GetTypeConstraint. Returns the i-th type constraint.
ConstOpSchemaTypeConstraint GetInputTypeConstraint(size_t index) const
Wraps OrtEpApi::OpSchema_GetNumOutputs.
size_t GetNumInputs() const
Wraps OrtEpApi::OpSchema_GetInputName.
std::string GetOutputName(size_t index) const
ConstOpSchemaTypeConstraint GetTypeConstraint(size_t index) const
Definition onnxruntime_cxx_api.h:3628
std::vector< size_t > GetOutputIndices() const
std::vector< size_t > GetInputIndices() const
Wraps OrtEpApi::OpSchemaTypeConstraint_GetOutputIndices.
std::vector< std::string > GetAllowedTypes() const
Wraps OrtEpApi::OpSchemaTypeConstraint_GetInputIndices.
std::string GetTypeParamName() const
Wraps OrtEpApi::OpSchemaTypeConstraint_GetTypeParamName.
Definition onnxruntime_cxx_api.h:2074
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
Definition onnxruntime_cxx_api.h:2163
const char ** str
Definition onnxruntime_cxx_api.h:2168
const int64_t * values_shape
Definition onnxruntime_cxx_api.h:2164
size_t values_shape_len
Definition onnxruntime_cxx_api.h:2165
const void * p_data
Definition onnxruntime_cxx_api.h:2167
Definition onnxruntime_cxx_api.h:1313
Ort::Status AddEvents(const std::vector< ProfilingEvent > &events)
Ort::Status AddEvents(const OrtProfilingEvent *const *events, size_t num_events)
Adds profiling events to this container. Events are copied. Wraps OrtEpApi::ProfilingEventsContainer_...
Definition onnxruntime_cxx_api.h:2050
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
Definition onnxruntime_cxx_api.h:1877
void SetEpDynamicOptions(const char *const *keys, const char *const *values, size_t kv_len)
Set DynamicOptions for EPs (Execution Providers)
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
void FinalizeModelEditorSession(const Model &model, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container=nullptr)
void Run(const RunOptions &run_options, const IoBinding &)
Wraps OrtApi::RunWithBinding.
void RunAsync(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count, RunAsyncCallbackFn callback, void *user_data)
Run the model asynchronously in a thread owned by intra op thread pool.
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
void Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count)
Run the model returning results in user provided outputs. Same as Run(const RunOptions&,...
Definition onnxruntime_cxx_api.h:2174
const int64_t * shape
Definition onnxruntime_cxx_api.h:2175
size_t shape_len
Definition onnxruntime_cxx_api.h:2176
Definition onnxruntime_cxx_api.h:3703
Status StoreWeightData(void **buffer_data_ptrs, size_t *buffer_sizes, size_t num_buffers)
Definition onnxruntime_cxx_api.h:1098
void * GetHandle()
Wraps SyncStream_GetHandle.
Definition onnxruntime_cxx_api.h:2000
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
std::vector< const char * > GetSymbolicDimensions() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
bool HasShape() const
Wraps OrtApi::TensorTypeAndShape_HasShape.
Definition onnxruntime_cxx_api.h:2112
ONNXType GetONNXType() const
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
ConstOptionalTypeInfo GetOptionalTypeInfo() const
Wraps OrtApi::CastTypeInfoToOptionalTypeInfo.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
Definition onnxruntime_cxx_api.h:679
T Type
Definition onnxruntime_cxx_api.h:680
Definition onnxruntime_cxx_api.h:2362
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor.
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. No type checking is performe...
R & At(const std::vector< int64_t > &location)
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
Memory allocation interface.
Definition onnxruntime_c_api.h:355
void(* Free)(struct OrtAllocator *this_, void *p)
Free a block of memory previously allocated with OrtAllocator::Alloc.
Definition onnxruntime_c_api.h:362
const OrtApi *(* GetApi)(uint32_t version)
Get a pointer to the requested version of the OrtApi.
Definition onnxruntime_c_api.h:936
The C API.
Definition onnxruntime_c_api.h:1291
const OrtEpApi *(* GetEpApi)(void)
Get the OrtEpApi instance for implementing an execution provider.
Definition onnxruntime_c_api.h:5764
const OrtInteropApi *(* GetInteropApi)(void)
Get the EP Interop API instance.
Definition onnxruntime_c_api.h:7008
const OrtCompileApi *(* GetCompileApi)(void)
Get the Compile API instance.
Definition onnxruntime_c_api.h:5496
void(* ReleaseTensorRTProviderOptions)(OrtTensorRTProviderOptionsV2 *input)
Release an OrtTensorRTProviderOptionsV2.
Definition onnxruntime_c_api.h:3547
const OrtModelEditorApi *(* GetModelEditorApi)(void)
Get the Model Editor API instance.
Definition onnxruntime_c_api.h:5438
void(* ReleaseCUDAProviderOptions)(OrtCUDAProviderOptionsV2 *input)
Release an OrtCUDAProviderOptionsV2.
Definition onnxruntime_c_api.h:4050
CUDA Provider Options.
Definition onnxruntime_c_api.h:636
The OrtCompileApi struct provides functions to compile ONNX models.
Definition onnxruntime_c_api.h:7993
Definition onnxruntime_c_api.h:7466
int(* GetVariadicInputHomogeneity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7512
OrtCustomOpInputOutputCharacteristic(* GetOutputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:7496
size_t(* GetInputTypeCount)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7484
int(* GetVariadicOutputMinArity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7516
size_t(* GetAliasMap)(int **input_index, int **output_index)
Definition onnxruntime_c_api.h:7549
int(* GetStartVersion)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7534
void(* ReleaseMayInplace)(int *input_index, int *output_index)
Definition onnxruntime_c_api.h:7546
const char *(* GetName)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7477
size_t(* GetOutputTypeCount)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7486
void(* KernelDestroy)(void *op_kernel)
Definition onnxruntime_c_api.h:7492
int(* GetVariadicOutputHomogeneity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7521
OrtMemType(* GetInputMemoryType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:7503
void *(* CreateKernel)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info)
Definition onnxruntime_c_api.h:7473
uint32_t version
Definition onnxruntime_c_api.h:7467
ONNXTensorElementDataType(* GetInputType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:7483
void(* ReleaseAliasMap)(int *input_index, int *output_index)
Definition onnxruntime_c_api.h:7550
OrtCustomOpInputOutputCharacteristic(* GetInputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:7495
const char *(* GetExecutionProviderType)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7480
ONNXTensorElementDataType(* GetOutputType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:7485
int(* GetVariadicInputMinArity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7507
OrtStatusPtr(* InferOutputShapeFn)(const struct OrtCustomOp *op, OrtShapeInferContext *)
Definition onnxruntime_c_api.h:7531
int(* GetEndVersion)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:7535
OrtStatusPtr(* CreateKernelV2)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info, void **kernel)
Definition onnxruntime_c_api.h:7524
size_t(* GetMayInplace)(int **input_index, int **output_index)
Definition onnxruntime_c_api.h:7542
OrtStatusPtr(* KernelComputeV2)(void *op_kernel, OrtKernelContext *context)
Definition onnxruntime_c_api.h:7529
void(* KernelCompute)(void *op_kernel, OrtKernelContext *context)
Definition onnxruntime_c_api.h:7491
Configuration options for creating an OrtEnv.
Definition onnxruntime_c_api.h:1203
The OrtEpApi struct provides functions that are relevant to the implementation of an execution provid...
Definition onnxruntime_ep_c_api.h:1021
The OrtEpFactory provides functions to create and manage execution providers.
Definition onnxruntime_ep_c_api.h:2596
The OrtEp struct provides functions to implement for an execution provider.
Definition onnxruntime_ep_c_api.h:2118
The OrtInteropApi struct provides functions for external resource interop with execution providers.
Definition onnxruntime_c_api.h:8289
MIGraphX Provider Options.
Definition onnxruntime_c_api.h:840
The OrtModelEditorApi struct provides functions to create or edit an ONNX model.
Definition onnxruntime_c_api.h:7564
OpenVINO Provider Options.
Definition onnxruntime_c_api.h:879
ROCM Provider Options.
Definition onnxruntime_c_api.h:723
TensorRT Provider Options.
Definition onnxruntime_c_api.h:812
Configuration for thread pool work callbacks.
Definition onnxruntime_c_api.h:1029