ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5//
6// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8// all the resources follow RAII and do not leak memory.
9//
10// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12// until you assign an instance that actually holds an underlying object.
13//
14// For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15// Some objects have explicit 'Clone' methods for this purpose.
16//
17// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18// by value or by reference. ConstXXXX types are restricted to const only interfaces.
19//
20// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21//
22// The lifetime of the corresponding owning object must outlast the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23// have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
24
25#pragma once
26#include "onnxruntime_c_api.h"
27#include "onnxruntime_float16.h"
28
#include <cstddef>
#include <cstdio>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include <type_traits>

#ifdef ORT_NO_EXCEPTIONS
#include <cstdlib>
#include <iostream>
#endif
43
47namespace Ort {
48
53struct Exception : std::exception {
54 Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
55
56 OrtErrorCode GetOrtErrorCode() const { return code_; }
57 const char* what() const noexcept override { return message_.c_str(); }
58
59 private:
60 std::string message_;
61 OrtErrorCode code_;
62};
63
#ifdef ORT_NO_EXCEPTIONS
// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
#ifndef ORT_CXX_API_THROW
// Without exceptions: print the error message to stderr and terminate the process.
// std::abort is spelled fully qualified because <cstdlib> guarantees std::abort; relying on an
// unqualified ::abort being transitively declared by another standard header is not portable.
#define ORT_CXX_API_THROW(string, code)       \
  do {                                        \
    std::cerr << Ort::Exception(string, code) \
                     .what()                  \
              << std::endl;                   \
    std::abort();                             \
  } while (false)
#endif
#else
// With exceptions: raise an Ort::Exception carrying the message and the error code.
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
80
81// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
82// it's in a template so that we can define a global variable in a header and make
83// it transparent to the users of the API.
84template <typename T>
85struct Global {
86 static const OrtApi* api_;
87};
88
89// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
90template <typename T>
91#ifdef ORT_API_MANUAL_INIT
92const OrtApi* Global<T>::api_{};
93inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
94
95// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
96// required by C++ APIs.
97//
98// Example mycustomop.cc:
99//
100// #define ORT_API_MANUAL_INIT
101// #include <onnxruntime_cxx_api.h>
102// #undef ORT_API_MANUAL_INIT
103//
104// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
105// Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
106// // ...
107// }
108//
109inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
110#else
111#if defined(_MSC_VER) && !defined(__clang__)
112#pragma warning(push)
113// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
114// Please define ORT_API_MANUAL_INIT if this concerns you.
115#pragma warning(disable : 26426)
116#endif
118#if defined(_MSC_VER) && !defined(__clang__)
119#pragma warning(pop)
120#endif
121#endif
122
124inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
125
/// Presumably the ONNX Runtime version as text (definition is not in this section).
auto GetVersionString() -> std::string;

/// Presumably a description of how this ONNX Runtime build was configured (definition is not
/// in this section).
auto GetBuildInfoString() -> std::string;

/// Presumably the names of the execution providers available in this build (definition is not
/// in this section).
auto GetAvailableProviders() -> std::vector<std::string>;
145
164struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
165 private:
171 constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
172
173 public:
174 using Base = onnxruntime_float16::Float16Impl<Float16_t>;
175
179 Float16_t() = default;
180
186 constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
187
192 explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
193
198 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
199
204 using Base::IsNegative;
205
210 using Base::IsNaN;
211
216 using Base::IsFinite;
217
222 using Base::IsPositiveInfinity;
223
228 using Base::IsNegativeInfinity;
229
234 using Base::IsInfinity;
235
240 using Base::IsNaNOrZero;
241
246 using Base::IsNormal;
247
252 using Base::IsSubnormal;
253
258 using Base::Abs;
259
264 using Base::Negate;
265
274 using Base::AreZero;
275
279 explicit operator float() const noexcept { return ToFloat(); }
280
281 using Base::operator==;
282 using Base::operator!=;
283 using Base::operator<;
284};
285
286static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
287
306struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
307 private:
315 constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
316
317 public:
318 using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
319
320 BFloat16_t() = default;
321
327 static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
328
333 explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
334
339 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
340
345 using Base::IsNegative;
346
351 using Base::IsNaN;
352
357 using Base::IsFinite;
358
363 using Base::IsPositiveInfinity;
364
369 using Base::IsNegativeInfinity;
370
375 using Base::IsInfinity;
376
381 using Base::IsNaNOrZero;
382
387 using Base::IsNormal;
388
393 using Base::IsSubnormal;
394
399 using Base::Abs;
400
405 using Base::Negate;
406
415 using Base::AreZero;
416
420 explicit operator float() const noexcept { return ToFloat(); }
421
422 // We do not have an inherited impl for the below operators
423 // as the internal class implements them a little differently
424 bool operator==(const BFloat16_t& rhs) const noexcept;
425 bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
426 bool operator<(const BFloat16_t& rhs) const noexcept;
427};
428
429static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
430
437 uint8_t value;
438 constexpr Float8E4M3FN_t() noexcept : value(0) {}
439 constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}
440 constexpr operator uint8_t() const noexcept { return value; }
441 // nan values are treated like any other value for operator ==, !=
442 constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; };
443 constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; };
444};
445
446static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
447
454 uint8_t value;
455 constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}
456 constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}
457 constexpr operator uint8_t() const noexcept { return value; }
458 // nan values are treated like any other value for operator ==, !=
459 constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; };
460 constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; };
461};
462
463static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
464
471 uint8_t value;
472 constexpr Float8E5M2_t() noexcept : value(0) {}
473 constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {}
474 constexpr operator uint8_t() const noexcept { return value; }
475 // nan values are treated like any other value for operator ==, !=
476 constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; };
477 constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; };
478};
479
480static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
481
488 uint8_t value;
489 constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}
490 constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}
491 constexpr operator uint8_t() const noexcept { return value; }
492 // nan values are treated like any other value for operator ==, !=
493 constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; };
494 constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; };
495};
496
497static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
498
499namespace detail {
500// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
501// This can't be done in the C API since C doesn't have function overloading.
502#define ORT_DEFINE_RELEASE(NAME) \
503 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
504
505ORT_DEFINE_RELEASE(Allocator);
506ORT_DEFINE_RELEASE(MemoryInfo);
507ORT_DEFINE_RELEASE(CustomOpDomain);
508ORT_DEFINE_RELEASE(ThreadingOptions);
509ORT_DEFINE_RELEASE(Env);
510ORT_DEFINE_RELEASE(RunOptions);
511ORT_DEFINE_RELEASE(LoraAdapter);
512ORT_DEFINE_RELEASE(Session);
513ORT_DEFINE_RELEASE(SessionOptions);
514ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
515ORT_DEFINE_RELEASE(SequenceTypeInfo);
516ORT_DEFINE_RELEASE(MapTypeInfo);
517ORT_DEFINE_RELEASE(TypeInfo);
518ORT_DEFINE_RELEASE(Value);
519ORT_DEFINE_RELEASE(ModelMetadata);
520ORT_DEFINE_RELEASE(IoBinding);
521ORT_DEFINE_RELEASE(ArenaCfg);
522ORT_DEFINE_RELEASE(Status);
523ORT_DEFINE_RELEASE(OpAttr);
524ORT_DEFINE_RELEASE(Op);
525ORT_DEFINE_RELEASE(KernelInfo);
526
527#undef ORT_DEFINE_RELEASE
528
// Tag template: wrapping a handle type as Base<Unowned<T>> selects the non-owning Base
// specialization, whose destructor is defaulted and does not release the handle.
// `Type` recovers the wrapped T.
template <class Wrapped>
struct Unowned {
  using Type = Wrapped;
};
536
556template <typename T>
557struct Base {
558 using contained_type = T;
559
560 constexpr Base() = default;
561 constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
563
564 Base(const Base&) = delete;
565 Base& operator=(const Base&) = delete;
566
567 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
568 Base& operator=(Base&& v) noexcept {
569 OrtRelease(p_);
570 p_ = v.release();
571 return *this;
572 }
573
574 constexpr operator contained_type*() const noexcept { return p_; }
575
579 T* p = p_;
580 p_ = nullptr;
581 return p;
582 }
583
584 protected:
586};
587
588// Undefined. For const types use Base<Unowned<const T>>
589template <typename T>
590struct Base<const T>;
591
599template <typename T>
600struct Base<Unowned<T>> {
602
603 constexpr Base() = default;
604 constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
605
606 ~Base() = default;
607
608 Base(const Base&) = default;
609 Base& operator=(const Base&) = default;
610
611 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
612 Base& operator=(Base&& v) noexcept {
613 p_ = nullptr;
614 std::swap(p_, v.p_);
615 return *this;
616 }
617
618 constexpr operator contained_type*() const noexcept { return p_; }
619
620 protected:
622};
623
624// Light functor to release memory with OrtAllocator
627 explicit AllocatedFree(OrtAllocator* allocator)
628 : allocator_(allocator) {}
629 void operator()(void* ptr) const {
630 if (ptr) allocator_->Free(allocator_, ptr);
631 }
632};
633
634} // namespace detail
635
636struct AllocatorWithDefaultOptions;
637struct Env;
638struct TypeInfo;
639struct Value;
640struct ModelMetadata;
641
646using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
647
652struct Status : detail::Base<OrtStatus> {
653 explicit Status(std::nullptr_t) noexcept {}
654 explicit Status(OrtStatus* status) noexcept;
655 explicit Status(const Exception&) noexcept;
656 explicit Status(const std::exception&) noexcept;
657 Status(const char* message, OrtErrorCode code) noexcept;
658 std::string GetErrorMessage() const;
660 bool IsOK() const noexcept;
661};
662
692
698struct Env : detail::Base<OrtEnv> {
699 explicit Env(std::nullptr_t) {}
700
702 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
703
705 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
706
708 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
709
711 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
712 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
713
715 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
716
719
721
722 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
723
724 Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg);
725};
726
730struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
731 explicit CustomOpDomain(std::nullptr_t) {}
732
734 explicit CustomOpDomain(const char* domain);
735
736 // This does not take ownership of the op, simply registers it.
737 void Add(const OrtCustomOp* op);
738};
739
741struct LoraAdapter : detail::Base<OrtLoraAdapter> {
743 using Base::Base;
744
745 explicit LoraAdapter(std::nullptr_t) {}
752 static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
753 OrtAllocator* allocator);
754
762 static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
763 OrtAllocator* allocator);
764};
765
769struct RunOptions : detail::Base<OrtRunOptions> {
770 explicit RunOptions(std::nullptr_t) {}
772
775
778
779 RunOptions& SetRunTag(const char* run_tag);
780 const char* GetRunTag() const;
781
782 RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
783
790
796
804};
805
namespace detail {
// Builds the SessionOptions config-entry key for one setting of a custom operator, shaped like
// custom_op.[custom_op_name].[config].
auto MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) -> std::string;
}  // namespace detail
811
822 CustomOpConfigs() = default;
823 ~CustomOpConfigs() = default;
828
837 CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
838
847 const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
848
849 private:
850 std::unordered_map<std::string, std::string> flat_configs_;
851};
852
858struct SessionOptions;
859
860namespace detail {
861// we separate const-only methods because passing const ptr to non-const methods
862// is only discovered when inline methods are compiled which is counter-intuitive
863template <typename T>
864struct ConstSessionOptionsImpl : Base<T> {
865 using B = Base<T>;
866 using B::B;
867
868 SessionOptions Clone() const;
869
870 std::string GetConfigEntry(const char* config_key) const;
871 bool HasConfigEntry(const char* config_key) const;
872 std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
873};
874
875template <typename T>
876struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
877 using B = ConstSessionOptionsImpl<T>;
878 using B::B;
879
880 SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads);
881 SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads);
882 SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
883 SessionOptionsImpl& SetDeterministicCompute(bool value);
884
885 SessionOptionsImpl& EnableCpuMemArena();
886 SessionOptionsImpl& DisableCpuMemArena();
887
888 SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
889
890 SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
891 SessionOptionsImpl& DisableProfiling();
892
893 SessionOptionsImpl& EnableOrtCustomOps();
894
895 SessionOptionsImpl& EnableMemPattern();
896 SessionOptionsImpl& DisableMemPattern();
897
898 SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode);
899
900 SessionOptionsImpl& SetLogId(const char* logid);
901 SessionOptionsImpl& SetLogSeverityLevel(int level);
902
903 SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain);
904
905 SessionOptionsImpl& DisablePerSessionThreads();
906
907 SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value);
908
909 SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val);
910 SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values);
911 SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
912 const std::vector<char*>& external_initializer_file_buffer_array,
913 const std::vector<size_t>& external_initializer_file_lengths);
914
915 SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
916 SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options);
917 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
918 SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
920 SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
921 SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
922 SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options);
923 SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options);
925 SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
927 SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
929 SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
930 const std::unordered_map<std::string, std::string>& provider_options = {});
931
932 SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
933 SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
934 SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
935
939 SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
940
941 SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name);
942
944 SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
945};
946} // namespace detail
947
948using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
949using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
950
954struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
955 explicit SessionOptions(std::nullptr_t) {}
957 explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {}
960};
961
965struct ModelMetadata : detail::Base<OrtModelMetadata> {
966 explicit ModelMetadata(std::nullptr_t) {}
968
976
984
992
1000
1008
1015 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const;
1016
1027
1028 int64_t GetVersion() const;
1029};
1030
1031struct IoBinding;
1032
1033namespace detail {
1034
1035// we separate const-only methods because passing const ptr to non-const methods
1036// is only discovered when inline methods are compiled which is counter-intuitive
1037template <typename T>
1039 using B = Base<T>;
1040 using B::B;
1041
1042 size_t GetInputCount() const;
1043 size_t GetOutputCount() const;
1045
1054
1063
1072
1073 uint64_t GetProfilingStartTimeNs() const;
1075
1076 TypeInfo GetInputTypeInfo(size_t index) const;
1077 TypeInfo GetOutputTypeInfo(size_t index) const;
1079};
1080
1081template <typename T>
1084 using B::B;
1085
1103 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1104 const char* const* output_names, size_t output_count);
1105
1109 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1110 const char* const* output_names, Value* output_values, size_t output_count);
1111
1112 void Run(const RunOptions& run_options, const IoBinding&);
1113
1133 void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1134 const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1135
1143
1155 void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len);
1156};
1157
1158} // namespace detail
1159
1162
1166struct Session : detail::SessionImpl<OrtSession> {
1167 explicit Session(std::nullptr_t) {}
1168 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
1169 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1170 OrtPrepackedWeightsContainer* prepacked_weights_container);
1171 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
1172 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1173 OrtPrepackedWeightsContainer* prepacked_weights_container);
1174
1175 ConstSession GetConst() const { return ConstSession{this->p_}; }
1176 UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
1177};
1178
1179namespace detail {
1180template <typename T>
1182 using B = Base<T>;
1183 using B::B;
1184
1185 std::string GetAllocatorName() const;
1187 int GetDeviceId() const;
1190
1191 template <typename U>
1192 bool operator==(const MemoryInfoImpl<U>& o) const;
1193};
1194} // namespace detail
1195
1196// Const object holder that does not own the underlying object
1198
1202struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
1204 explicit MemoryInfo(std::nullptr_t) {}
1205 explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {}
1206 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1207 ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
1208};
1209
1210namespace detail {
1211template <typename T>
1213 using B = Base<T>;
1214 using B::B;
1215
1217 size_t GetElementCount() const;
1218
1219 size_t GetDimensionsCount() const;
1220
1225 [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const;
1226
1227 void GetSymbolicDimensions(const char** values, size_t values_count) const;
1228
1229 std::vector<int64_t> GetShape() const;
1230};
1231
1232} // namespace detail
1233
1235
1240 explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
1241 explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {}
1243};
1244
1245namespace detail {
1246template <typename T>
1248 using B = Base<T>;
1249 using B::B;
1251};
1252
1253} // namespace detail
1254
1256
1260struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
1261 explicit SequenceTypeInfo(std::nullptr_t) {}
1262 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {}
1264};
1265
1266namespace detail {
1267template <typename T>
1269 using B = Base<T>;
1270 using B::B;
1272};
1273
1274} // namespace detail
1275
1276// This is always owned by the TypeInfo and can only be obtained from it.
1278
1279namespace detail {
1280template <typename T>
1287
1288} // namespace detail
1289
1291
1295struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
1296 explicit MapTypeInfo(std::nullptr_t) {}
1297 explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {}
1298 ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
1299};
1300
1301namespace detail {
1302template <typename T>
1314} // namespace detail
1315
1321
1326struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
1327 explicit TypeInfo(std::nullptr_t) {}
1328 explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {}
1329
1330 ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
1331};
1332
1333namespace detail {
1334// This structure is used to feed sparse tensor values
1335// information for use with FillSparseTensor<Format>() API
1336// if the data type for the sparse tensor values is numeric
1337// use data.p_data, otherwise, use data.str pointer to feed
1338// values. data.str is an array of const char* that are zero terminated.
1339// number of strings in the array must match shape size.
1340// For fully sparse tensors use shape {0} and set p_data/str
1341// to nullptr.
1343 const int64_t* values_shape;
1345 union {
1346 const void* p_data;
1347 const char** str;
1348 } data;
1349};
1350
1351// Provides a way to pass shape in a single
1352// argument
1353struct Shape {
1354 const int64_t* shape;
1356};
1357
1358template <typename T>
1360 using B = Base<T>;
1361 using B::B;
1362
1366 template <typename R>
1367 void GetOpaqueData(const char* domain, const char* type_name, R&) const;
1368
1369 bool IsTensor() const;
1370 bool HasValue() const;
1371
1372 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
1373 Value GetValue(int index, OrtAllocator* allocator) const;
1374
1382
1397 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1398
1405 template <typename R>
1406 const R* GetTensorData() const;
1407
1412 const void* GetTensorRawData() const;
1413
1421
1429
1435
1444 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1445
1452 std::string GetStringTensorElement(size_t element_index) const;
1453
1460 size_t GetStringTensorElementLength(size_t element_index) const;
1461
1462#if !defined(DISABLE_SPARSE_TENSORS)
1470
1477
1486
1496 template <typename R>
1497 const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1498
1503 bool IsSparseTensor() const;
1504
1513 template <typename R>
1514 const R* GetSparseTensorValues() const;
1515
1516#endif
1517};
1518
1519template <typename T>
1522 using B::B;
1523
1529 template <typename R>
1531
1537
1539 // Obtain a reference to an element of data at the location specified
1545 template <typename R>
1546 R& At(const std::vector<int64_t>& location);
1547
1553 void FillStringTensor(const char* const* s, size_t s_len);
1554
1560 void FillStringTensorElement(const char* s, size_t index);
1561
1574 char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
1575
1576#if !defined(DISABLE_SPARSE_TENSORS)
1585 void UseCooIndices(int64_t* indices_data, size_t indices_num);
1586
1597 void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1598
1607 void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1608
1618 void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1619 const int64_t* indices_data, size_t indices_num);
1620
1632 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1633 const OrtSparseValuesParam& values,
1634 const int64_t* inner_indices_data, size_t inner_indices_num,
1635 const int64_t* outer_indices_data, size_t outer_indices_num);
1636
1647 const OrtSparseValuesParam& values,
1648 const Shape& indices_shape,
1649 const int32_t* indices_data);
1650
1651#endif
1652};
1653
1654} // namespace detail
1655
1658
1662struct Value : detail::ValueImpl<OrtValue> {
1666
1667 explicit Value(std::nullptr_t) {}
1668 explicit Value(OrtValue* p) : Base{p} {}
1669 Value(Value&&) = default;
1670 Value& operator=(Value&&) = default;
1671
1672 ConstValue GetConst() const { return ConstValue{this->p_}; }
1673 UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1674
1683 template <typename T>
1684 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1685
1695 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1697
1709 template <typename T>
1710 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1711
1723 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1724
1733 static Value CreateMap(const Value& keys, const Value& values);
1734
1742 static Value CreateSequence(const std::vector<Value>& values);
1743
1752 template <typename T>
1753 static Value CreateOpaque(const char* domain, const char* type_name, const T& value);
1754
1755#if !defined(DISABLE_SPARSE_TENSORS)
1766 template <typename T>
1767 static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1768 const Shape& values_shape);
1769
1786 static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1787 const Shape& values_shape, ONNXTensorElementDataType type);
1788
1798 template <typename T>
1799 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1800
1812 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1813
1814#endif // !defined(DISABLE_SPARSE_TENSORS)
1815};
1816
1824 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1829 MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1830
1831 void* get() { return p_; }
1832 size_t size() const { return size_; }
1833
1834 private:
1835 OrtAllocator* allocator_;
1836 void* p_;
1837 size_t size_;
1838};
1839
1840namespace detail {
1841template <typename T>
1842struct AllocatorImpl : Base<T> {
1843 using B = Base<T>;
1844 using B::B;
1845
1846 void* Alloc(size_t size);
1847 MemoryAllocation GetAllocation(size_t size);
1848 void Free(void* p);
1849 ConstMemoryInfo GetInfo() const;
1850};
1851
1852} // namespace detail
1853
1857struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1858 explicit AllocatorWithDefaultOptions(std::nullptr_t) {}
1860};
1861
1865struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1866 explicit Allocator(std::nullptr_t) {}
1867 Allocator(const Session& session, const OrtMemoryInfo*);
1868};
1869
1870using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1871
1872namespace detail {
1873namespace binding_utils {
1874// Bring these out of template
1875std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1876std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1877} // namespace binding_utils
1878
1879template <typename T>
1881 using B = Base<T>;
1882 using B::B;
1883
1884 std::vector<std::string> GetOutputNames() const;
1885 std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1886 std::vector<Value> GetOutputValues() const;
1887 std::vector<Value> GetOutputValues(OrtAllocator*) const;
1888};
1889
1890template <typename T>
1893 using B::B;
1894
1895 void BindInput(const char* name, const Value&);
1896 void BindOutput(const char* name, const Value&);
1897 void BindOutput(const char* name, const OrtMemoryInfo*);
1902};
1903
1904} // namespace detail
1905
1908
1912struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {  // owning wrapper around OrtIoBinding
1913 explicit IoBinding(std::nullptr_t) {}  // empty object; initialize it later by assignment
1914 explicit IoBinding(Session& session);  // NOTE(review): presumably creates a binding for `session` — confirm
1915 ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }  // const-only view of the same handle; does not own it
1916 UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }  // non-owning view that also allows non-const calls
1917};
1918
1923struct ArenaCfg : detail::Base<OrtArenaCfg> {  // configuration of an arena-based allocator
1924 explicit ArenaCfg(std::nullptr_t) {}  // empty; must be assigned a valid ArenaCfg before use
1933 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);  // NOTE(review): parameter semantics follow the C API arena config — confirm there
1934};
1935
1936//
1937// Custom OPs (only needed to implement custom OPs)
1938//
1939
1943struct OpAttr : detail::Base<OrtOpAttr> {  // wrapper around OrtOpAttr; passed as `attr_values` to Op::Create
1944 OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);  // NOTE(review): `len` presumably counts elements of `type` in `data` — confirm
1945};
1946
// ORT_CXX_LOG: logs `message` through `logger` when `message_severity` is at or above the logger's
// current severity level. The Status returned by LogMessage is passed to Ort::ThrowOnError.
// Every macro argument is parenthesized at each use so that expressions such as `*logger_ptr` or
// `cond ? ORT_LOGGING_LEVEL_ERROR : ORT_LOGGING_LEVEL_INFO` expand correctly.
// NOTE: `logger` and `message_severity` are evaluated more than once; pass side-effect-free expressions.
#define ORT_CXX_LOG(logger, message_severity, message)                                            \
  do {                                                                                            \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                               \
      Ort::ThrowOnError((logger).LogMessage((message_severity), ORT_FILE, __LINE__,               \
                                            static_cast<const char*>(__FUNCTION__), (message)));  \
    }                                                                                             \
  } while (false)

// ORT_CXX_LOG_NOEXCEPT: same as ORT_CXX_LOG, but the Status returned by LogMessage is discarded
// instead of being passed to Ort::ThrowOnError. The same argument-evaluation caveat applies.
#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message)                                   \
  do {                                                                                            \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                               \
      static_cast<void>((logger).LogMessage((message_severity), ORT_FILE, __LINE__,               \
                                            static_cast<const char*>(__FUNCTION__), (message)));  \
    }                                                                                             \
  } while (false)

// ORT_CXX_LOGF: formatted variant of ORT_CXX_LOG. The trailing arguments (a format string followed
// by its values) are forwarded to LogFormattedMessage; the returned Status is passed to
// Ort::ThrowOnError.
#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...)                                    \
  do {                                                                                            \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                               \
      Ort::ThrowOnError((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,      \
                                                     static_cast<const char*>(__FUNCTION__),      \
                                                     __VA_ARGS__));                               \
    }                                                                                             \
  } while (false)

// ORT_CXX_LOGF_NOEXCEPT: formatted variant of ORT_CXX_LOG_NOEXCEPT; the returned Status is discarded.
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...)                           \
  do {                                                                                            \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                               \
      static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,      \
                                                     static_cast<const char*>(__FUNCTION__),      \
                                                     __VA_ARGS__));                               \
    }                                                                                             \
  } while (false)
2016
2027struct Logger {  // non-owning handle: stores a const OrtLogger* with defaulted copy/move/destructor
2031 Logger() = default;
2032
2036 explicit Logger(std::nullptr_t) {}  // empty logger (logger_ stays null)
2037
2044 explicit Logger(const OrtLogger* logger);  // NOTE(review): presumably also fills cached_severity_level_ — confirm
2045
2046 ~Logger() = default;  // does not release logger_
2047
2048 Logger(const Logger&) = default;
2049 Logger& operator=(const Logger&) = default;
2050
2051 Logger(Logger&& v) noexcept = default;
2052 Logger& operator=(Logger&& v) noexcept = default;
2053
2060
2073 Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,  // returns a Status instead of throwing (noexcept)
2074 const char* func_name, const char* message) const noexcept;
2075
2090 template <typename... Args>
2091 Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,  // like LogMessage, but the message is built from `format` and `args`
2092 const char* func_name, const char* format, Args&&... args) const noexcept;
2093
2094 private:
2095 const OrtLogger* logger_{};  // not owned
2096 OrtLoggingLevel cached_severity_level_{};  // NOTE(review): presumably what GetLoggingSeverityLevel() returns — confirm
2097};
2098
2107 size_t GetInputCount() const;
2108 size_t GetOutputCount() const;
2109 // If the input is optional and not present, the method returns an empty ConstValue
2110 // which can be compared to nullptr.
2111 ConstValue GetInput(size_t index) const;
2112 // If the output is optional and not present, the method returns an empty UnownedValue
2113 // which can be compared to nullptr.
2114 UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
2115 UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
2116 void* GetGPUComputeStream() const;
2118 OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
2119 OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
2120 void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
2121
2122 private:
2123 OrtKernelContext* ctx_;
2124};
2125
2126struct KernelInfo;  // forward declaration: KernelInfoImpl::Copy() returns it
2127
2128namespace detail {
2129namespace attr_utils {  // non-template attribute readers shared by every KernelInfoImpl instantiation
2130void GetAttr(const OrtKernelInfo* p, const char* name, float&);
2131void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
2132void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
2133void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
2134void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2135} // namespace attr_utils
2136
2137template <typename T>
2138struct KernelInfoImpl : Base<T> {
2139 using B = Base<T>;
2140 using B::B;
2141
2142 KernelInfo Copy() const;  // returns a KernelInfo that owns its OrtKernelInfo* copy
2143
2144 template <typename R> // R is only implemented for float, int64_t, and std::string (the attr_utils::GetAttr overloads)
2145 R GetAttribute(const char* name) const {
2146 R val;
2147 attr_utils::GetAttr(this->p_, name, val);
2148 return val;
2149 }
2150
2151 template <typename R> // R is only implemented for float and int64_t; the result is std::vector<R>
2152 std::vector<R> GetAttributes(const char* name) const {
2153 std::vector<R> result;
2154 attr_utils::GetAttrs(this->p_, name, result);
2155 return result;
2156 }
2157
2158 Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
2159
2160 size_t GetInputCount() const;
2161 size_t GetOutputCount() const;
2162
2163 std::string GetInputName(size_t index) const;
2164 std::string GetOutputName(size_t index) const;
2165
2166 TypeInfo GetInputTypeInfo(size_t index) const;
2167 TypeInfo GetOutputTypeInfo(size_t index) const;
2168
2169 ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;  // NOTE(review): `is_constant` is an out-parameter — exact semantics to confirm
2170
2171 std::string GetNodeName() const;
2172 Logger GetLogger() const;  // Logger is a non-owning handle
2173};
2174
2175} // namespace detail
2176
2177using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;  // const, non-owning view
2178
2185struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {  // owns the OrtKernelInfo* (e.g. when produced by Copy())
2186 explicit KernelInfo(std::nullptr_t) {}  // empty instance to initialize later
2187 explicit KernelInfo(OrtKernelInfo* info);  // takes ownership of `info`
2188 ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }  // const, non-owning view of the same handle
2189};
2190
2194struct Op : detail::Base<OrtOp> {  // wrapper around OrtOp
2195 explicit Op(std::nullptr_t) {}  // empty placeholder
2196
2197 explicit Op(OrtOp*);  // NOTE(review): presumably takes ownership of the pointer — confirm
2198
2199 static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
2200 int version, const char** type_constraint_names,  // NOTE(review): names/values presumably parallel arrays of type_constraint_count entries
2201 const ONNXTensorElementDataType* type_constraint_values,
2202 size_t type_constraint_count,
2203 const OpAttr* attr_values,  // presumably attr_count entries
2204 size_t attr_count,
2205 size_t input_count, size_t output_count);
2206
2207 void Invoke(const OrtKernelContext* context,  // overload taking C++ Value wrappers
2208 const Value* input_values,
2209 size_t input_count,
2210 Value* output_values,
2211 size_t output_count);
2212
2213 // Overload taking raw C API OrtValue handles, for easier refactoring of C API code
2214 void Invoke(const OrtKernelContext* context,
2215 const OrtValue* const* input_values,
2216 size_t input_count,
2217 OrtValue* const* output_values,
2218 size_t output_count);
2219};
2220
2226 SymbolicInteger(int64_t i) : i_(i), is_int_(true) {};
2227 SymbolicInteger(const char* s) : s_(s), is_int_(false) {};
2230
2233
2234 bool operator==(const SymbolicInteger& dim) const {
2235 if (is_int_ == dim.is_int_) {
2236 if (is_int_) {
2237 return i_ == dim.i_;
2238 } else {
2239 return std::string{s_} == std::string{dim.s_};
2240 }
2241 }
2242 return false;
2243 }
2244
2245 bool IsInt() const { return is_int_; }
2246 int64_t AsInt() const { return i_; }
2247 const char* AsSym() const { return s_; }
2248
2249 static constexpr int INVALID_INT_DIM = -2;
2250
2251 private:
2252 union {
2253 int64_t i_;
2254 const char* s_;
2255 };
2256 bool is_int_;
2257 };
2258
2259 using Shape = std::vector<SymbolicInteger>;
2260
2262
2263 const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
2264
2265 size_t GetInputCount() const { return input_shapes_.size(); }
2266
2268
2269 int64_t GetAttrInt(const char* attr_name);
2270
2271 using Ints = std::vector<int64_t>;
2272 Ints GetAttrInts(const char* attr_name);
2273
2274 float GetAttrFloat(const char* attr_name);
2275
2276 using Floats = std::vector<float>;
2277 Floats GetAttrFloats(const char* attr_name);
2278
2279 std::string GetAttrString(const char* attr_name);
2280
2281 using Strings = std::vector<std::string>;
2282 Strings GetAttrStrings(const char* attr_name);
2283
2284 private:
2285 const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
2286 const OrtApi* ort_api_;
2288 std::vector<Shape> input_shapes_;
2289};
2290
2292
// Default upper bound for a custom op's `end_ver_`: 2^31 - 1.
// The whole expansion is parenthesized so the macro acts as a single value inside larger
// expressions; without the outer parentheses, `2 * MAX_CUSTOM_OP_END_VER` would expand to
// `2 * (1UL << 31) - 1`.
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
2294
2295template <typename TOp, typename TKernel, bool WithStatus = false>
2299 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
2300
2301 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2302
2303 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2304 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2305 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2306
2307 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2308 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2309
2310#if defined(_MSC_VER) && !defined(__clang__)
2311#pragma warning(push)
2312#pragma warning(disable : 26409)
2313#endif
2314 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
2315#if defined(_MSC_VER) && !defined(__clang__)
2316#pragma warning(pop)
2317#endif
2318 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2319 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2320
2321 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2322 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
2323 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2324 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
2325#ifdef __cpp_if_constexpr
2326 if constexpr (WithStatus) {
2327#else
2328 if (WithStatus) {
2329#endif
2330 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
2331 return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
2332 };
2333 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2334 return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
2335 };
2336 } else {
2339
2340 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2341 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
2342 static_cast<TKernel*>(op_kernel)->Compute(context);
2343 };
2344 }
2345
2346 SetShapeInferFn<TOp>(0);
2347
2348 OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2349 return static_cast<const TOp*>(this_)->start_ver_;
2350 };
2351
2352 OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2353 return static_cast<const TOp*>(this_)->end_ver_;
2354 };
2355
2358 OrtCustomOp::GetAliasMap = nullptr;
2360 }
2361
2362 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2363 const char* GetExecutionProviderType() const { return nullptr; }
2364
2365 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2366 // (inputs and outputs are required by default)
2368 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2369 }
2370
2372 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2373 }
2374
2375 // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
2376 OrtMemType GetInputMemoryType(size_t /*index*/) const {
2377 return OrtMemTypeDefault;
2378 }
2379
2380 // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
2381 // should expect at least 1 argument.
2383 return 1;
2384 }
2385
2386 // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
2387 // to a variadic input should be of the same type.
2389 return true;
2390 }
2391
2392 // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
2393 // should produce at least 1 output value.
2395 return 1;
2396 }
2397
2398 // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
2399 // produced by a variadic output should be of the same type.
2401 return true;
2402 }
2403
2404 // Declare list of session config entries used by this Custom Op.
2405 // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
2406 // This default implementation returns an empty vector of config entries.
2407 std::vector<std::string> GetSessionConfigKeys() const {
2408 return std::vector<std::string>{};
2409 }
2410
2411 template <typename C>
2412 decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
2414 ShapeInferContext ctx(&GetApi(), ort_ctx);
2415 return C::InferOutputShape(ctx);
2416 };
2417 return {};
2418 }
2419
2420 template <typename C>
2424
2425 protected:
2426 // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
2427 void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2428
2429 int start_ver_ = 1;
2430 int end_ver_ = MAX_CUSTOM_OP_END_VER;
2431};
2432
2433} // namespace Ort
2434
2435#include "onnxruntime_cxx_inline.h"
struct OrtMemoryInfo OrtMemoryInfo
Definition onnxruntime_c_api.h:282
struct OrtKernelInfo OrtKernelInfo
Definition onnxruntime_c_api.h:369
OrtLoggingLevel
Logging severity levels.
Definition onnxruntime_c_api.h:237
OrtMemoryInfoDeviceType
This mimics OrtDevice type constants so they can be returned in the API.
Definition onnxruntime_c_api.h:393
struct OrtShapeInferContext OrtShapeInferContext
Definition onnxruntime_c_api.h:306
void(* OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
Definition onnxruntime_c_api.h:334
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
Definition onnxruntime_c_api.h:717
OrtCustomOpInputOutputCharacteristic
Definition onnxruntime_c_api.h:4771
struct OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsV2
Definition onnxruntime_c_api.h:299
struct OrtOpAttr OrtOpAttr
Definition onnxruntime_c_api.h:304
struct OrtThreadingOptions OrtThreadingOptions
Definition onnxruntime_c_api.h:296
struct OrtSequenceTypeInfo OrtSequenceTypeInfo
Definition onnxruntime_c_api.h:290
struct OrtDnnlProviderOptions OrtDnnlProviderOptions
Definition onnxruntime_c_api.h:302
OrtSparseIndicesFormat
Definition onnxruntime_c_api.h:226
struct OrtPrepackedWeightsContainer OrtPrepackedWeightsContainer
Definition onnxruntime_c_api.h:298
struct OrtCustomOpDomain OrtCustomOpDomain
Definition onnxruntime_c_api.h:293
struct OrtIoBinding OrtIoBinding
Definition onnxruntime_c_api.h:283
OrtAllocatorType
Definition onnxruntime_c_api.h:375
struct OrtOp OrtOp
Definition onnxruntime_c_api.h:303
struct OrtModelMetadata OrtModelMetadata
Definition onnxruntime_c_api.h:294
struct OrtTypeInfo OrtTypeInfo
Definition onnxruntime_c_api.h:287
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:288
struct OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2
Definition onnxruntime_c_api.h:300
struct OrtKernelContext OrtKernelContext
Definition onnxruntime_c_api.h:371
struct OrtCANNProviderOptions OrtCANNProviderOptions
Definition onnxruntime_c_api.h:301
void(* RunAsyncCallbackFn)(void *user_data, OrtValue **outputs, size_t num_outputs, OrtStatusPtr status)
Callback function for RunAsync.
Definition onnxruntime_c_api.h:728
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:292
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:285
GraphOptimizationLevel
Graph optimization level.
Definition onnxruntime_c_api.h:343
OrtStatus * OrtStatusPtr
Definition onnxruntime_c_api.h:312
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
Definition onnxruntime_c_api.h:384
OrtSparseFormat
Definition onnxruntime_c_api.h:218
ONNXType
Definition onnxruntime_c_api.h:206
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:280
OrtErrorCode
Definition onnxruntime_c_api.h:245
struct OrtStatus OrtStatus
Definition onnxruntime_c_api.h:281
#define ORT_API_VERSION
The API version defined in this header.
Definition onnxruntime_c_api.h:41
struct OrtLogger OrtLogger
Definition onnxruntime_c_api.h:305
struct OrtMapTypeInfo OrtMapTypeInfo
Definition onnxruntime_c_api.h:289
struct OrtArenaCfg OrtArenaCfg
Definition onnxruntime_c_api.h:297
ExecutionMode
Definition onnxruntime_c_api.h:350
OrtOpAttrType
Definition onnxruntime_c_api.h:260
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
Definition onnxruntime_c_api.h:710
ONNXTensorElementDataType
Definition onnxruntime_c_api.h:177
const OrtApiBase * OrtGetApiBase(void)
The Onnxruntime library's entry point to access the C API.
@ ORT_LOGGING_LEVEL_WARNING
Warning messages.
Definition onnxruntime_c_api.h:240
@ OrtMemTypeDefault
The default allocator for execution provider.
Definition onnxruntime_c_api.h:388
@ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
Definition onnxruntime_c_api.h:179
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
void OrtRelease(OrtAllocator *ptr)
Definition onnxruntime_cxx_api.h:505
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
All C++ Onnxruntime APIs are defined inside this namespace.
Definition onnxruntime_cxx_api.h:47
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Definition onnxruntime_cxx_api.h:646
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions > > ConstSessionOptions
Definition onnxruntime_cxx_api.h:949
detail::KernelInfoImpl< detail::Unowned< const OrtKernelInfo > > ConstKernelInfo
Definition onnxruntime_cxx_api.h:2177
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition onnxruntime_cxx_api.h:124
detail::AllocatorImpl< detail::Unowned< OrtAllocator > > UnownedAllocator
Definition onnxruntime_cxx_api.h:1870
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions > > UnownedSessionOptions
Definition onnxruntime_cxx_api.h:948
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id,...
std::string GetVersionString()
This function returns the onnxruntime version string.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Ort::Status(*)(Ort::ShapeInferContext &) ShapeInferFn
Definition onnxruntime_cxx_api.h:2291
Wrapper around OrtAllocator.
Definition onnxruntime_cxx_api.h:1865
Allocator(const Session &session, const OrtMemoryInfo *)
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition onnxruntime_cxx_api.h:1866
Wrapper around OrtAllocator default instance that is owned by Onnxruntime.
Definition onnxruntime_cxx_api.h:1857
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Definition onnxruntime_cxx_api.h:1858
A structure that represents the configuration of an arena-based allocator.
Definition onnxruntime_cxx_api.h:1923
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1924
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk)
bfloat16 (Brain Floating Point) data type
Definition onnxruntime_cxx_api.h:306
bool operator==(const BFloat16_t &rhs) const noexcept
onnxruntime_float16::BFloat16Impl< BFloat16_t > Base
Definition onnxruntime_cxx_api.h:318
BFloat16_t()=default
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Constructs a BFloat16_t from its raw uint16_t bit representation.
Definition onnxruntime_cxx_api.h:327
bool operator!=(const BFloat16_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:425
BFloat16_t(float v) noexcept
__ctor from float. Float is converted into bfloat16 16-bit representation.
Definition onnxruntime_cxx_api.h:333
float ToFloat() const noexcept
Converts bfloat16 to float.
Definition onnxruntime_cxx_api.h:339
bool operator<(const BFloat16_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:2296
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
Definition onnxruntime_cxx_api.h:2371
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
Definition onnxruntime_cxx_api.h:2367
OrtMemType GetInputMemoryType(size_t) const
Definition onnxruntime_cxx_api.h:2376
std::vector< std::string > GetSessionConfigKeys() const
Definition onnxruntime_cxx_api.h:2407
bool GetVariadicInputHomogeneity() const
Definition onnxruntime_cxx_api.h:2388
int GetVariadicInputMinArity() const
Definition onnxruntime_cxx_api.h:2382
void SetShapeInferFn(...)
Definition onnxruntime_cxx_api.h:2421
CustomOpBase()
Definition onnxruntime_cxx_api.h:2297
bool GetVariadicOutputHomogeneity() const
Definition onnxruntime_cxx_api.h:2400
int GetVariadicOutputMinArity() const
Definition onnxruntime_cxx_api.h:2394
decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape))
Definition onnxruntime_cxx_api.h:2412
const char * GetExecutionProviderType() const
Definition onnxruntime_cxx_api.h:2363
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
Class that represents session configuration entries for one or more custom operators.
Definition onnxruntime_cxx_api.h:821
~CustomOpConfigs()=default
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
CustomOpConfigs & operator=(CustomOpConfigs &&o)=default
CustomOpConfigs(CustomOpConfigs &&o)=default
CustomOpConfigs()=default
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
CustomOpConfigs(const CustomOpConfigs &)=default
CustomOpConfigs & operator=(const CustomOpConfigs &)=default
Custom Op Domain.
Definition onnxruntime_cxx_api.h:730
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:731
CustomOpDomain(const char *domain)
Wraps OrtApi::CreateCustomOpDomain.
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
The Env (Environment)
Definition onnxruntime_cxx_api.h:698
Env & EnableTelemetryEvents()
Wraps OrtApi::EnableTelemetryEvents.
Env(OrtEnv *p)
C Interop Helper.
Definition onnxruntime_cxx_api.h:715
Env & CreateAndRegisterAllocatorV2(const std::string &provider_type, const OrtMemoryInfo *mem_info, const std::unordered_map< std::string, std::string > &options, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocatorV2.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:699
Env(OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnv.
Env(const OrtThreadingOptions *tp_options, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithGlobalThreadPools.
Env(const OrtThreadingOptions *tp_options, OrtLoggingFunction logging_function, void *logger_param, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools.
Env(OrtLoggingLevel logging_level, const char *logid, OrtLoggingFunction logging_function, void *logger_param)
Wraps OrtApi::CreateEnvWithCustomLogger.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
All C++ methods that can fail will throw an exception of this type.
Definition onnxruntime_cxx_api.h:53
const char * what() const noexcept override
Definition onnxruntime_cxx_api.h:57
OrtErrorCode GetOrtErrorCode() const
Definition onnxruntime_cxx_api.h:56
Exception(std::string &&string, OrtErrorCode code)
Definition onnxruntime_cxx_api.h:54
IEEE 754 half-precision floating point data type.
Definition onnxruntime_cxx_api.h:164
Float16_t()=default
Default constructor.
Float16_t(float v) noexcept
__ctor from float. Float is converted into float16 16-bit representation.
Definition onnxruntime_cxx_api.h:192
onnxruntime_float16::Float16Impl< Float16_t > Base
Definition onnxruntime_cxx_api.h:174
float ToFloat() const noexcept
Converts float16 to float.
Definition onnxruntime_cxx_api.h:198
static constexpr Float16_t FromBits(uint16_t v) noexcept
Constructs a Float16_t from its raw uint16_t bit representation.
Definition onnxruntime_cxx_api.h:186
float8e4m3fn (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:436
uint8_t value
Definition onnxruntime_cxx_api.h:437
constexpr Float8E4M3FN_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:439
constexpr bool operator==(const Float8E4M3FN_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:442
constexpr Float8E4M3FN_t() noexcept
Definition onnxruntime_cxx_api.h:438
constexpr bool operator!=(const Float8E4M3FN_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:443
float8e4m3fnuz (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:453
constexpr bool operator==(const Float8E4M3FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:459
uint8_t value
Definition onnxruntime_cxx_api.h:454
constexpr Float8E4M3FNUZ_t() noexcept
Definition onnxruntime_cxx_api.h:455
constexpr bool operator!=(const Float8E4M3FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:460
constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:456
float8e5m2 (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:470
constexpr Float8E5M2_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:473
uint8_t value
Definition onnxruntime_cxx_api.h:471
constexpr bool operator!=(const Float8E5M2_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:477
constexpr Float8E5M2_t() noexcept
Definition onnxruntime_cxx_api.h:472
constexpr bool operator==(const Float8E5M2_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:476
float8e5m2fnuz (Float8 Floating Point) data type
Definition onnxruntime_cxx_api.h:487
constexpr Float8E5M2FNUZ_t() noexcept
Definition onnxruntime_cxx_api.h:489
constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept
Definition onnxruntime_cxx_api.h:490
constexpr bool operator!=(const Float8E5M2FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:494
constexpr bool operator==(const Float8E5M2FNUZ_t &rhs) const noexcept
Definition onnxruntime_cxx_api.h:493
uint8_t value
Definition onnxruntime_cxx_api.h:488
Definition onnxruntime_cxx_api.h:85
static const OrtApi * api_
Definition onnxruntime_cxx_api.h:86
Wrapper around OrtIoBinding.
Definition onnxruntime_cxx_api.h:1912
UnownedIoBinding GetUnowned() const
Definition onnxruntime_cxx_api.h:1916
ConstIoBinding GetConst() const
Definition onnxruntime_cxx_api.h:1915
IoBinding(Session &session)
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
Definition onnxruntime_cxx_api.h:1913
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
Definition onnxruntime_cxx_api.h:2105
KernelContext(OrtKernelContext *context)
Logger GetLogger() const
ConstValue GetInput(size_t index) const
OrtKernelContext * GetOrtKernelContext() const
Definition onnxruntime_cxx_api.h:2119
void ParallelFor(void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
OrtAllocator * GetAllocator(const OrtMemoryInfo &memory_info) const
void * GetGPUComputeStream() const
size_t GetInputCount() const
size_t GetOutputCount() const
UnownedValue GetOutput(size_t index, const std::vector< int64_t > &dims) const
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
Definition onnxruntime_cxx_api.h:2185
KernelInfo(OrtKernelInfo *info)
Take ownership of the instance.
ConstKernelInfo GetConst() const
Definition onnxruntime_cxx_api.h:2188
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
Definition onnxruntime_cxx_api.h:2186
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Definition onnxruntime_cxx_api.h:2027
Logger(Logger &&v) noexcept=default
Logger & operator=(Logger &&v) noexcept=default
Logger & operator=(const Logger &)=default
~Logger()=default
Logger(const Logger &)=default
Logger()=default
Logger(std::nullptr_t)
Definition onnxruntime_cxx_api.h:2036
Logger(const OrtLogger *logger)
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
LoraAdapter holds a set of Lora Parameters loaded from a single file.
Definition onnxruntime_cxx_api.h:741
static LoraAdapter CreateLoraAdapter(const std::basic_string< char > &adapter_path, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapter.
LoraAdapter(std::nullptr_t)
Definition onnxruntime_cxx_api.h:745
static LoraAdapter CreateLoraAdapterFromArray(const void *bytes, size_t num_bytes, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapterFromArray.
Wrapper around OrtMapTypeInfo.
Definition onnxruntime_cxx_api.h:1295
ConstMapTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1298
MapTypeInfo(OrtMapTypeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1297
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1296
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime....
Definition onnxruntime_cxx_api.h:1823
MemoryAllocation(MemoryAllocation &&) noexcept
MemoryAllocation & operator=(const MemoryAllocation &)=delete
MemoryAllocation(const MemoryAllocation &)=delete
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
size_t size() const
Definition onnxruntime_cxx_api.h:1832
Wrapper around OrtMemoryInfo.
Definition onnxruntime_cxx_api.h:1202
MemoryInfo(const char *name, OrtAllocatorType type, int id, OrtMemType mem_type)
MemoryInfo(std::nullptr_t)
No instance is created.
Definition onnxruntime_cxx_api.h:1204
MemoryInfo(OrtMemoryInfo *p)
Take ownership of a pointer created by C Api.
Definition onnxruntime_cxx_api.h:1205
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
ConstMemoryInfo GetConst() const
Definition onnxruntime_cxx_api.h:1207
Wrapper around OrtModelMetadata.
Definition onnxruntime_cxx_api.h:965
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:966
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator *allocator) const
Returns a copy of the producer name.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Returns a copy of the graph name.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
ModelMetadata(OrtModelMetadata *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:967
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
This struct provides life time management for custom op attribute.
Definition onnxruntime_cxx_api.h:1943
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
Create and own custom defined operation.
Definition onnxruntime_cxx_api.h:2194
Op(OrtOp *)
Take ownership of the OrtOp.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:2195
void Invoke(const OrtKernelContext *context, const OrtValue *const *input_values, size_t input_count, OrtValue *const *output_values, size_t output_count)
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
RunOptions.
Definition onnxruntime_cxx_api.h:769
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance.
RunOptions & SetRunTag(const char *run_tag)
wraps OrtApi::RunOptionsSetRunTag
RunOptions & AddActiveLoraAdapter(const LoraAdapter &adapter)
Add the LoraAdapter to the list of active adapters. The setting does not affect RunWithBinding() call...
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:770
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
RunOptions()
Wraps OrtApi::CreateRunOptions.
Wrapper around OrtSequenceTypeInfo.
Definition onnxruntime_cxx_api.h:1260
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1261
ConstSequenceTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1263
SequenceTypeInfo(OrtSequenceTypeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1262
Wrapper around OrtSession.
Definition onnxruntime_cxx_api.h:1166
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1167
UnownedSession GetUnowned() const
Definition onnxruntime_cxx_api.h:1176
Session(const Env &env, const char *model_path, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer.
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer.
Session(const Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtApi::CreateSession.
ConstSession GetConst() const
Definition onnxruntime_cxx_api.h:1175
Session(const Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtApi::CreateSessionFromArray.
Wrapper around OrtSessionOptions.
Definition onnxruntime_cxx_api.h:954
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:955
UnownedSessionOptions GetUnowned() const
Definition onnxruntime_cxx_api.h:958
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
ConstSessionOptions GetConst() const
Definition onnxruntime_cxx_api.h:959
SessionOptions(OrtSessionOptions *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:957
Definition onnxruntime_cxx_api.h:2225
SymbolicInteger & operator=(const SymbolicInteger &)=default
SymbolicInteger(const SymbolicInteger &)=default
int64_t AsInt() const
Definition onnxruntime_cxx_api.h:2246
int64_t i_
Definition onnxruntime_cxx_api.h:2253
const char * s_
Definition onnxruntime_cxx_api.h:2254
bool operator==(const SymbolicInteger &dim) const
Definition onnxruntime_cxx_api.h:2234
SymbolicInteger & operator=(SymbolicInteger &&)=default
SymbolicInteger(SymbolicInteger &&)=default
const char * AsSym() const
Definition onnxruntime_cxx_api.h:2247
SymbolicInteger(int64_t i)
Definition onnxruntime_cxx_api.h:2226
SymbolicInteger(const char *s)
Definition onnxruntime_cxx_api.h:2227
bool IsInt() const
Definition onnxruntime_cxx_api.h:2245
Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
Definition onnxruntime_cxx_api.h:2224
Ints GetAttrInts(const char *attr_name)
Strings GetAttrStrings(const char *attr_name)
Status SetOutputShape(size_t indice, const Shape &shape, ONNXTensorElementDataType type=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
std::vector< SymbolicInteger > Shape
Definition onnxruntime_cxx_api.h:2259
std::vector< float > Floats
Definition onnxruntime_cxx_api.h:2276
std::string GetAttrString(const char *attr_name)
std::vector< int64_t > Ints
Definition onnxruntime_cxx_api.h:2271
ShapeInferContext(const OrtApi *ort_api, OrtShapeInferContext *ctx)
int64_t GetAttrInt(const char *attr_name)
size_t GetInputCount() const
Definition onnxruntime_cxx_api.h:2265
std::vector< std::string > Strings
Definition onnxruntime_cxx_api.h:2281
Floats GetAttrFloats(const char *attr_name)
const Shape & GetInputShape(size_t indice) const
Definition onnxruntime_cxx_api.h:2263
float GetAttrFloat(const char *attr_name)
The Status that holds ownership of OrtStatus received from C API Use it to safely destroy OrtStatus* ...
Definition onnxruntime_cxx_api.h:652
OrtErrorCode GetErrorCode() const
Status(const char *message, OrtErrorCode code) noexcept
Creates status instance out of null-terminated string message.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
Status(OrtStatus *status) noexcept
Takes ownership of OrtStatus instance returned from the C API.
std::string GetErrorMessage() const
Status(const Exception &) noexcept
Creates status instance out of exception.
Status(const std::exception &) noexcept
Creates status instance out of exception.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:653
Wrapper around OrtTensorTypeAndShapeInfo.
Definition onnxruntime_cxx_api.h:1239
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1240
ConstTensorTypeAndShapeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1242
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1241
The ThreadingOptions.
Definition onnxruntime_cxx_api.h:667
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
Definition onnxruntime_cxx_api.h:1326
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1327
ConstTypeInfo GetConst() const
Definition onnxruntime_cxx_api.h:1330
TypeInfo(OrtTypeInfo *p)
C API Interop.
Definition onnxruntime_cxx_api.h:1328
Wrapper around OrtValue.
Definition onnxruntime_cxx_api.h:1662
static Value CreateSparseTensor(const OrtMemoryInfo *info, void *p_data, const Shape &dense_shape, const Shape &values_shape, ONNXTensorElementDataType type)
Creates an OrtValue instance containing SparseTensor. This constructs a sparse tensor that makes use ...
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
Value & operator=(Value &&)=default
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape, ONNXTensorElementDataType type)
Creates an instance of OrtValue containing sparse tensor. The created instance has no data....
Value(Value &&)=default
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Definition onnxruntime_cxx_api.h:1667
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
Value(OrtValue *p)
Used for interop with the C API.
Definition onnxruntime_cxx_api.h:1668
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape)
This is a simple forwarding method to the below CreateSparseTensor. This helps to specify data type e...
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates an OrtValue with a tensor using the supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtVal...
UnownedValue GetUnowned() const
Definition onnxruntime_cxx_api.h:1673
static Value CreateSequence(const std::vector< Value > &values)
Creates an OrtValue with a Sequence Onnx type representation. The API would ref-count the supplied Or...
static Value CreateMap(const Value &keys, const Value &values)
Creates an OrtValue with a Map Onnx type representation. The API would ref-count the supplied OrtValu...
static Value CreateTensor(const OrtMemoryInfo *info, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len)
Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue...
static Value CreateOpaque(const char *domain, const char *type_name, const T &value)
Creates an OrtValue wrapping an Opaque type. This is used for experimental support of non-tensor type...
ConstValue GetConst() const
Definition onnxruntime_cxx_api.h:1672
Definition onnxruntime_cxx_api.h:625
AllocatedFree(OrtAllocator *allocator)
Definition onnxruntime_cxx_api.h:627
OrtAllocator * allocator_
Definition onnxruntime_cxx_api.h:626
void operator()(void *ptr) const
Definition onnxruntime_cxx_api.h:629
Base & operator=(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:612
typename Unowned< T >::Type contained_type
Definition onnxruntime_cxx_api.h:601
Base(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:611
Base(const Base &)=default
constexpr Base(contained_type *p) noexcept
Definition onnxruntime_cxx_api.h:604
Base & operator=(const Base &)=default
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition onnxruntime_cxx_api.h:557
Base(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:567
constexpr Base()=default
contained_type * release()
Relinquishes ownership of the contained C object pointer The underlying object is not destroyed.
Definition onnxruntime_cxx_api.h:578
Base(const Base &)=delete
constexpr Base(contained_type *p) noexcept
Definition onnxruntime_cxx_api.h:561
Base & operator=(const Base &)=delete
Base & operator=(Base &&v) noexcept
Definition onnxruntime_cxx_api.h:568
contained_type * p_
Definition onnxruntime_cxx_api.h:585
~Base()
Definition onnxruntime_cxx_api.h:562
T contained_type
Definition onnxruntime_cxx_api.h:558
Definition onnxruntime_cxx_api.h:1880
std::vector< Value > GetOutputValues(OrtAllocator *) const
std::vector< std::string > GetOutputNames(OrtAllocator *) const
std::vector< Value > GetOutputValues() const
std::vector< std::string > GetOutputNames() const
Definition onnxruntime_cxx_api.h:1038
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
Returns the number of model outputs.
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
size_t GetInputCount() const
Returns the number of model inputs.
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
Definition onnxruntime_cxx_api.h:1359
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
std::string GetStringTensorElement(size_t element_index) const
Returns string tensor UTF-8 encoded string element. Use of this API is recommended over GetStringTens...
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a sparse...
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor....
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor.
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values....
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
Value GetValue(int index, OrtAllocator *allocator) const
size_t GetCount() const
< Return true if OrtValue contains data and returns false if the OrtValue is a None
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor....
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed,...
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
Definition onnxruntime_cxx_api.h:1891
void BindOutput(const char *name, const Value &)
void BindInput(const char *name, const Value &)
void BindOutput(const char *name, const OrtMemoryInfo *)
Definition onnxruntime_cxx_api.h:1281
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
Definition onnxruntime_cxx_api.h:1181
std::string GetAllocatorName() const
OrtMemType GetMemoryType() const
OrtMemoryInfoDeviceType GetDeviceType() const
OrtAllocatorType GetAllocatorType() const
bool operator==(const MemoryInfoImpl< U > &o) const
Definition onnxruntime_cxx_api.h:1268
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
Definition onnxruntime_cxx_api.h:1342
const char ** str
Definition onnxruntime_cxx_api.h:1347
const int64_t * values_shape
Definition onnxruntime_cxx_api.h:1343
size_t values_shape_len
Definition onnxruntime_cxx_api.h:1344
const void * p_data
Definition onnxruntime_cxx_api.h:1346
Definition onnxruntime_cxx_api.h:1247
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
Definition onnxruntime_cxx_api.h:1082
void SetEpDynamicOptions(const char *const *keys, const char *const *values, size_t kv_len)
Set DynamicOptions for EPs (Execution Providers)
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
void Run(const RunOptions &run_options, const IoBinding &)
Wraps OrtApi::RunWithBinding.
void RunAsync(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count, RunAsyncCallbackFn callback, void *user_data)
Run the model asynchronously in a thread owned by intra op thread pool.
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
void Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count)
Run the model returning results in user provided outputs Same as Run(const RunOptions&,...
Definition onnxruntime_cxx_api.h:1353
const int64_t * shape
Definition onnxruntime_cxx_api.h:1354
size_t shape_len
Definition onnxruntime_cxx_api.h:1355
Definition onnxruntime_cxx_api.h:1212
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
Definition onnxruntime_cxx_api.h:1303
ONNXType GetONNXType() const
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
ConstOptionalTypeInfo GetOptionalTypeInfo() const
wraps OrtApi::CastTypeInfoToOptionalTypeInfo
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
Definition onnxruntime_cxx_api.h:533
T Type
Definition onnxruntime_cxx_api.h:534
Definition onnxruntime_cxx_api.h:1520
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor.
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer No type checking is performe...
R & At(const std::vector< int64_t > &location)
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
Memory allocation interface.
Definition onnxruntime_c_api.h:321
void(* Free)(struct OrtAllocator *this_, void *p)
Free a block of memory previously allocated with OrtAllocator::Alloc.
Definition onnxruntime_c_api.h:324
const OrtApi *(* GetApi)(uint32_t version)
Get a pointer to the requested version of the OrtApi.
Definition onnxruntime_c_api.h:677
The C API.
Definition onnxruntime_c_api.h:737
CUDA Provider Options.
Definition onnxruntime_c_api.h:411
Definition onnxruntime_c_api.h:4781
int(* GetVariadicInputHomogeneity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4827
OrtCustomOpInputOutputCharacteristic(* GetOutputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:4811
size_t(* GetInputTypeCount)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4799
int(* GetVariadicOutputMinArity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4831
size_t(* GetAliasMap)(int **input_index, int **output_index)
Definition onnxruntime_c_api.h:4864
int(* GetStartVersion)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4849
void(* ReleaseMayInplace)(int *input_index, int *output_index)
Definition onnxruntime_c_api.h:4861
const char *(* GetName)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4792
size_t(* GetOutputTypeCount)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4801
void(* KernelDestroy)(void *op_kernel)
Definition onnxruntime_c_api.h:4807
int(* GetVariadicOutputHomogeneity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4836
OrtMemType(* GetInputMemoryType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:4818
void *(* CreateKernel)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info)
Definition onnxruntime_c_api.h:4788
uint32_t version
Definition onnxruntime_c_api.h:4782
ONNXTensorElementDataType(* GetInputType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:4798
void(* ReleaseAliasMap)(int *input_index, int *output_index)
Definition onnxruntime_c_api.h:4865
OrtCustomOpInputOutputCharacteristic(* GetInputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:4810
const char *(* GetExecutionProviderType)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4795
ONNXTensorElementDataType(* GetOutputType)(const struct OrtCustomOp *op, size_t index)
Definition onnxruntime_c_api.h:4800
int(* GetVariadicInputMinArity)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4822
OrtStatusPtr(* InferOutputShapeFn)(const struct OrtCustomOp *op, OrtShapeInferContext *)
Definition onnxruntime_c_api.h:4846
int(* GetEndVersion)(const struct OrtCustomOp *op)
Definition onnxruntime_c_api.h:4850
OrtStatusPtr(* CreateKernelV2)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info, void **kernel)
Definition onnxruntime_c_api.h:4839
size_t(* GetMayInplace)(int **input_index, int **output_index)
Definition onnxruntime_c_api.h:4857
OrtStatusPtr(* KernelComputeV2)(void *op_kernel, OrtKernelContext *context)
Definition onnxruntime_c_api.h:4844
void(* KernelCompute)(void *op_kernel, OrtKernelContext *context)
Definition onnxruntime_c_api.h:4806
MIGraphX Provider Options.
Definition onnxruntime_c_api.h:615
OpenVINO Provider Options.
Definition onnxruntime_c_api.h:632
ROCM Provider Options.
Definition onnxruntime_c_api.h:498
TensorRT Provider Options.
Definition onnxruntime_c_api.h:587