ONNX Runtime
Training C & C++ APIs

Training C and C++ APIs are an extension of the onnxruntime core C and C++ APIs and should be used in conjunction with them.

In order to train a model with onnxruntime, the following training artifacts must be generated:

  • The training onnx model
  • The checkpoint file
  • The optimizer onnx model
  • The eval onnx model (optional)

These training artifacts can be generated in an offline step using the Python utilities provided by the onnxruntime-training Python package.

After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
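Because session creation depends on the artifacts listed above, it can help to verify that they exist before calling into the training API. The following is a minimal sketch of such a check using only the C++17 standard library. The `TrainingArtifacts` struct and the `MissingArtifacts` function are hypothetical names introduced for illustration and are not part of onnxruntime.

```cpp
#include <cassert>
#include <filesystem>
#include <optional>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical container for the artifact locations; the field names are
// illustrative and not part of the onnxruntime API.
struct TrainingArtifacts {
    fs::path training_model;
    fs::path checkpoint;
    fs::path optimizer_model;
    std::optional<fs::path> eval_model;  // optional, as in the list above
};

// Returns a label for every listed artifact that is missing on disk.
// The eval model is only checked when a path was supplied.
std::vector<std::string> MissingArtifacts(const TrainingArtifacts& a) {
    std::vector<std::string> missing;
    if (!fs::exists(a.training_model)) missing.push_back("training model");
    if (!fs::exists(a.checkpoint)) missing.push_back("checkpoint");
    if (!fs::exists(a.optimizer_model)) missing.push_back("optimizer model");
    if (a.eval_model && !fs::exists(*a.eval_model)) missing.push_back("eval model");
    return missing;
}
```

An empty result means every supplied path exists; it says nothing about whether the files are valid models or a valid checkpoint.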

If any problem is encountered, please create an issue with your scenario and requirements, and we will be sure to respond and follow up on the request.

Training C API

OrtTrainingApi - Training C API functions.

This C structure contains functions that enable users to perform training with onnxruntime.

Sample Code:

#include <onnxruntime_training_c_api.h>

// Most calls below return an OrtStatus* that should be checked; checks are omitted here.
const OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
const OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
OrtEnv* env = NULL;
g_ort_api->CreateEnv(logging_level, logid, &env);
OrtSessionOptions* session_options = NULL;
g_ort_api->CreateSessionOptions(&session_options);
OrtCheckpointState* state = NULL;
g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
OrtTrainingSession* training_session = NULL;
// The checkpoint state precedes the model paths in the argument list.
g_ort_training_api->CreateTrainingSession(env, session_options, state, training_model_path,
                                          eval_model_path, optimizer_model_path,
                                          &training_session);
// Training loop
{
    g_ort_training_api->TrainStep(...);
    g_ort_training_api->OptimizerStep(...);
    g_ort_training_api->LazyResetGrad(...);
}
g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
// Release the session before the checkpoint state it depends on.
g_ort_training_api->ReleaseTrainingSession(training_session);
g_ort_training_api->ReleaseCheckpointState(state);

Declarations used by the sample code:

  • struct OrtSessionOptions
    Defined at onnxruntime_c_api.h:292
  • struct OrtEnv
    Defined at onnxruntime_c_api.h:280
  • #define ORT_API_VERSION
    The API version defined in this header.
    Defined at onnxruntime_c_api.h:41
  • const OrtApiBase * OrtGetApiBase(void)
    The Onnxruntime library's entry point to access the C API.
  • struct OrtTrainingSession
    Defined at onnxruntime_training_c_api.h:104
  • struct OrtCheckpointState
    Defined at onnxruntime_training_c_api.h:105
  • const OrtApi *(* GetApi)(uint32_t version)
    Get a pointer to the requested version of the OrtApi.
    Defined at onnxruntime_c_api.h:682
  • struct OrtApi
    The C API.
    Defined at onnxruntime_c_api.h:742
  • OrtStatus * CreateSessionOptions(OrtSessionOptions **options)
    Create an OrtSessionOptions object.
  • const OrtTrainingApi *(* GetTrainingApi)(uint32_t version)
    Gets the Training C Api struct.
    Defined at onnxruntime_c_api.h:3736
  • OrtStatus * CreateEnv(OrtLoggingLevel log_severity_level, const char *logid, OrtEnv **out)
    Create an OrtEnv.
  • struct OrtTrainingApi
    The Training C API that holds onnxruntime training function pointers.
    Defined at onnxruntime_training_c_api.h:122
  • OrtStatus * LazyResetGrad(OrtTrainingSession *session)
    Reset the gradients of all trainable parameters to zero lazily.
  • OrtStatus * CreateTrainingSession(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
    Create a training session that can be used to begin or resume training.
  • OrtStatus * LoadCheckpoint(const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
    Load a checkpoint state from a file on disk into checkpoint_state.
  • OrtStatus * TrainStep(OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
    Computes the outputs of the training model and the gradients of the trainable parameters for the give...
  • OrtStatus * ExportModelForInferencing(OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
    Export a model that can be used for inferencing.
  • void ReleaseTrainingSession(OrtTrainingSession *input)
    Frees up the memory used up by the training session.
  • void ReleaseCheckpointState(OrtCheckpointState *input)
    Frees up the memory used up by the checkpoint state.
  • OrtStatus * SaveCheckpoint(OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
    Save the given state to a checkpoint file on disk.
  • OrtStatus * OptimizerStep(OrtTrainingSession *sess, const OrtRunOptions *run_options)
    Performs the weight updates for the trainable parameters using the optimizer model.

Note: The OrtCheckpointState holds the entire training state used by the OrtTrainingSession, so the session must have access to it at all times. The OrtCheckpointState instance must therefore outlive the OrtTrainingSession instance.
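The sample's training loop calls TrainStep, OptimizerStep and LazyResetGrad in that order on every iteration. The sketch below illustrates why, using a toy single-weight model in plain C++. `ToySession` and `Fit` are hypothetical stand-ins, not the onnxruntime implementation, and the "lazy" behavior shown (the gradient buffer is cleared at the start of the next train step rather than immediately) reflects an assumed reading of the LazyResetGrad description.

```cpp
#include <cassert>
#include <cstddef>

// Toy stand-in for a training session; NOT the onnxruntime implementation.
// It fits a single weight w to minimize (w * x - y)^2.
class ToySession {
public:
    explicit ToySession(double learning_rate) : lr_(learning_rate) {}

    // Forward and backward pass: returns the loss and accumulates the gradient.
    // If a reset was requested, the stale gradient is discarded first.
    double TrainStep(double x, double y) {
        if (reset_pending_) {
            grad_ = 0.0;
            reset_pending_ = false;
        }
        const double error = weight_ * x - y;
        grad_ += 2.0 * error * x;
        return error * error;
    }

    // Applies the accumulated gradient to the weight.
    void OptimizerStep() { weight_ -= lr_ * grad_; }

    // Only records the request; the buffer is zeroed by the next TrainStep.
    void LazyResetGrad() { reset_pending_ = true; }

    double weight() const { return weight_; }
    double grad() const { return grad_; }

private:
    double lr_;
    double weight_ = 0.0;
    double grad_ = 0.0;
    bool reset_pending_ = false;
};

// Runs the three-call loop on one (x, y) sample and returns the last loss.
double Fit(ToySession& session, double x, double y, std::size_t steps) {
    double loss = 0.0;
    for (std::size_t i = 0; i < steps; ++i) {
        loss = session.TrainStep(x, y);  // compute loss, accumulate gradient
        session.OptimizerStep();         // update the weight
        session.LazyResetGrad();         // schedule the gradient reset
    }
    return loss;
}
```

In this toy model, skipping the reset call makes each train step add to the previous gradient, which is the behavior one would rely on for gradient accumulation across several batches.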

Training C++ API

Ort Training C++ API - Training C++ API classes and functions.

These C++ classes and functions enable users to perform training with onnxruntime.

Sample Code:

#include <onnxruntime_training_cxx_api.h>

Ort::Env env;
Ort::SessionOptions session_options;
auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
                                             eval_model_path, optimizer_model_path);
// Training loop
{
    training_session.TrainStep(...);
    training_session.OptimizerStep(...);
    training_session.LazyResetGrad(...);
}
training_session.ExportModelForInferencing(inference_model_path, ...);
Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);

Declarations used by the sample code:

  • static CheckpointState LoadCheckpoint(const std::basic_string< char > &path_to_checkpoint)
    Load a checkpoint state from a file on disk into checkpoint_state.
  • static void SaveCheckpoint(const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
    Save the given state to a checkpoint file on disk.
  • Ort::TrainingSession
    Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
    Defined at onnxruntime_training_cxx_api.h:180
  • Ort::Env
    The Env (Environment).
    Defined at onnxruntime_cxx_api.h:701
  • Ort::SessionOptions
    Wrapper around OrtSessionOptions.
    Defined at onnxruntime_cxx_api.h:960

Note: The Ort::CheckpointState holds the entire training state used by the Ort::TrainingSession, so the session must have access to it at all times. The Ort::CheckpointState instance must therefore outlive the Ort::TrainingSession instance.
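One simple way to satisfy this lifetime requirement is to declare the state before the session in the same scope, since C++ destroys automatic objects in reverse order of construction. The sketch below demonstrates that ordering with hypothetical stand-in types (`StateStub`, `SessionStub`) that mimic only the ownership relationship; they are not the Ort::CheckpointState or Ort::TrainingSession classes.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the checkpoint state; logs its own destruction.
struct StateStub {
    explicit StateStub(std::vector<std::string>& log) : log_(log) {}
    ~StateStub() { log_.push_back("state destroyed"); }
    std::vector<std::string>& log_;
};

// Hypothetical stand-in for the session; it keeps a reference to the state
// rather than a copy, so the state must still be alive whenever it is used.
struct SessionStub {
    SessionStub(StateStub& state, std::vector<std::string>& log)
        : state_(state), log_(log) {}
    ~SessionStub() { log_.push_back("session destroyed"); }
    StateStub& state_;
    std::vector<std::string>& log_;
};

// Declares the state first so that it is destroyed last.
std::vector<std::string> RunScopedTraining() {
    std::vector<std::string> log;
    {
        StateStub state(log);
        SessionStub session(state, log);
        // ... the training loop would run here ...
    }  // session is destroyed first, then state
    return log;
}
```

Calling `RunScopedTraining()` returns the two log entries with the session's destruction recorded before the state's. Reversing the two declarations would not compile here, because the session's constructor needs the state; with heap-allocated or moved objects the order has to be enforced manually.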