ONNX Runtime
Loading...
Searching...
No Matches
Ort::KernelContext Struct Reference

This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() method. Use it to safely access context attributes, input and output parameters with exception safety guarantees. See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc. More...

#include <onnxruntime_cxx_api.h>

Public Member Functions

 KernelContext (OrtKernelContext *context)
 
size_t GetInputCount () const
 
size_t GetOutputCount () const
 
ConstValue GetInput (size_t index) const
 
UnownedValue GetOutput (size_t index, const int64_t *dim_values, size_t dim_count) const
 
UnownedValue GetOutput (size_t index, const std::vector< int64_t > &dims) const
 
void * GetGPUComputeStream () const
 
Logger GetLogger () const
 
OrtAllocator * GetAllocator (const OrtMemoryInfo &memory_info) const
 
OrtKernelContext * GetOrtKernelContext () const
 
void ParallelFor (void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
 

Detailed Description

This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() method. Use it to safely access context attributes, input and output parameters with exception safety guarantees. See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc.

Constructor & Destructor Documentation

◆ KernelContext()

Ort::KernelContext::KernelContext ( OrtKernelContext *  context)
explicit

Member Function Documentation

◆ GetAllocator()

OrtAllocator * Ort::KernelContext::GetAllocator ( const OrtMemoryInfo &  memory_info) const

◆ GetGPUComputeStream()

void * Ort::KernelContext::GetGPUComputeStream ( ) const

◆ GetInput()

ConstValue Ort::KernelContext::GetInput ( size_t  index) const

◆ GetInputCount()

size_t Ort::KernelContext::GetInputCount ( ) const

◆ GetLogger()

Logger Ort::KernelContext::GetLogger ( ) const

◆ GetOrtKernelContext()

OrtKernelContext * Ort::KernelContext::GetOrtKernelContext ( ) const
inline

◆ GetOutput() [1/2]

UnownedValue Ort::KernelContext::GetOutput ( size_t  index,
const int64_t *  dim_values,
size_t  dim_count 
) const

◆ GetOutput() [2/2]

UnownedValue Ort::KernelContext::GetOutput ( size_t  index,
const std::vector< int64_t > &  dims 
) const

◆ GetOutputCount()

size_t Ort::KernelContext::GetOutputCount ( ) const

◆ ParallelFor()

void Ort::KernelContext::ParallelFor ( void(*)(void *, size_t)  fn,
size_t  total,
size_t  num_batch,
void *  usr_data 
) const